def main(args):
    """Restore the CB_iFL (rmodel, model) pair and resume training.

    Creates the save directory, writes ``str(args)`` to ``opts.txt``, builds a
    DataParallel ``UNet`` (``rmodel``) and a ``Net`` whose encoder weights are
    copied from an ImageNet-pretrained ERFNet checkpoint, overwrites both with
    the saved best checkpoints and finally hands them to ``train``.
    """
    # NOTE(review): absolute, machine-specific paths -- adjust before reuse.
    save_root = '/home/shyam.nandan/NewExp/final_code/save/'
    savedir = save_root + args.savedir

    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # Keep a copy of the run configuration next to the outputs.
    with open(savedir + '/opts.txt', "w") as opts_file:
        opts_file.write(str(args))

    rmodel = torch.nn.DataParallel(UNet()).cuda()

    # Borrow the encoder of an ImageNet-pretrained ERFNet ...
    encoder_source = torch.nn.DataParallel(ERFNet_imagenet(1000))
    encoder_source.load_state_dict(
        torch.load(args.pretrainedEncoder)['state_dict'])
    encoder = next(encoder_source.children()).features.encoder
    del encoder_source  # only the encoder sub-module is needed from here on

    # ... and graft it into a freshly built segmentation network.
    model = torch.nn.DataParallel(fill_weights(Net(NUM_CLASSES), encoder)).cuda()

    checkpoint_dir = '/home/shyam.nandan/NewExp/final_code/results/CB_iFL/'
    rmodel.load_state_dict(torch.load(checkpoint_dir + 'rmodel_best.pth'))
    model.load_state_dict(torch.load(checkpoint_dir + 'model_best.pth'))

    model = train(args, rmodel, model, False)
def test():
    """Denoise every test image and write noisy / denoised / ground-truth PNGs.

    All settings come from the module-level ``conf`` object. For each sample of
    ``Testinging_Dataset`` three files named after the source image are written
    into ``conf.denoised_dir``.
    """
    device = torch.device(conf.cuda if torch.cuda.is_available() else "cpu")
    test_dataset = Testinging_Dataset(conf.data_path_test,
                                      conf.test_noise_param,
                                      conf.crop_img_size)
    # shuffle=False keeps batch_idx aligned with dataset.image_list below.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    print('Loading model from: {}'.format(conf.model_path_test))
    model = UNet(in_channels=conf.img_channel, out_channels=conf.img_channel)
    print('loading model')
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(conf.model_path_test, map_location=device))
    model.eval()
    model.to(device)

    result_dir = conf.denoised_dir
    # makedirs (unlike mkdir) also creates missing parent directories.
    os.makedirs(result_dir, exist_ok=True)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch_idx, (source, img_cropped) in enumerate(test_loader):
            source_img = tvF.to_pil_image(source.squeeze(0))
            img_truth = img_cropped.squeeze(0).numpy().astype(np.uint8)
            denoised_img = model(source.to(device)).cpu()

            img_name = test_loader.dataset.image_list[batch_idx]
            fname = os.path.splitext(img_name)[0]
            # Clamp to [0, 1] before converting the output tensor to an image.
            denoised_result = tvF.to_pil_image(
                torch.clamp(denoised_img.squeeze(0), 0, 1))

            source_img.save(os.path.join(result_dir, f'{fname}-noisy.png'))
            denoised_result.save(
                os.path.join(result_dir, f'{fname}-denoised.png'))
            io.imsave(os.path.join(result_dir, f'{fname}-ground_truth.png'),
                      img_truth)
Exemplo n.º 3
0
def main(args):
    """Run the two-stage UNet cascade on the prediction set and score it.

    ``model`` produces a first-stage output that is concatenated (dim 1) with
    the input and fed to ``model2``. The arg-max label of each sample is
    compared with the label image read from ``args.label_root_dir``. Predictions
    containing any foreground are saved (as 0/255 PNGs) to ``args.preDir`` and
    their 2x2 confusion matrices are accumulated into ``cm_w``, which is then
    reduced to OA / F1 / IoU by ``evaluate``.
    """
    if args.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
    model = UNet(n_channels=args.colordim, n_classes=args.num_class)
    model2 = UNet(n_channels=args.colordim2, n_classes=args.num_class2)
    if args.cuda:
        model = model.cuda()
        model2 = model2.cuda()
    # On CPU runs remap GPU-saved tensors so the checkpoints still load.
    map_location = None if args.cuda else 'cpu'
    model.load_state_dict(torch.load(args.pretrain_net, map_location=map_location))
    model2.load_state_dict(torch.load(args.pretrain_net2, map_location=map_location))
    model.eval()
    model2.eval()
    predDataset = generateDataset(args.pre_root_dir,
                                  args.img_size,
                                  args.colordim,
                                  isTrain=False)
    predLoader = DataLoader(dataset=predDataset,
                            batch_size=args.predictbatchsize,
                            num_workers=args.threads)
    with torch.no_grad():
        cm_w = np.zeros((2, 2))
        for batch_x, batch_name in predLoader:
            # Cast on every device so the CPU path gets float input as well.
            batch_x = batch_x.float()
            if args.cuda:
                batch_x = batch_x.cuda()

            out1 = model(batch_x)
            out = model2(torch.cat((batch_x, out1), 1))
            _, pred_label = torch.max(out, 1)
            pred_label_np = pred_label.cpu().numpy()
            for idx, name in enumerate(batch_name):
                predLabel_filename = args.preDir + '/' + name + '.png'

                pred_label_single = pred_label_np[idx, :, :]
                label = io.imread(args.label_root_dir + name + '.png')
                # labels=[0, 1] always yields a 2x2 matrix; without it an image
                # where label and prediction share a single class gives a 1x1
                # matrix that would broadcast into all four cells of cm_w.
                cm = confusion_matrix(label.ravel(), pred_label_single.ravel(),
                                      labels=[0, 1])
                pred_label_single = np.where(pred_label_single > 0, 255, 0)
                print(np.max(pred_label_single))
                print(name)
                # NOTE(review): images whose prediction is all background are
                # neither saved nor counted in cm_w -- confirm this is intended.
                if np.max(pred_label_single) > 0:
                    io.imsave(predLabel_filename,
                              pred_label_single.astype(np.uint8))
                    cm_w = cm_w + cm

        print(cm_w)
        OA_w, F1_w, IoU_w = evaluate(cm_w)
        print('OA_w = ' + str(OA_w) + ', F1_w = ' + str(F1_w) + ', IoU = ' +
              str(IoU_w))
Exemplo n.º 4
0
def _segment_frames(net, volume_np, device):
    """Segment each frame along axis 2 of ``volume_np``; return masks stacked on axis 2."""
    masks = []
    for i in range(volume_np.shape[2]):
        img_np = volume_np[:, :, i]
        img_tensor = pre_transform(img_np).to(device)
        mask = predict_img(net, img_tensor)
        # Only the first three predicted channels are passed on.
        masks.append(post_transform(img_np, mask[0:3, :, :]))
    return np.concatenate(masks, axis=2)


def submit_mnms(model_path, input_data_directory, output_data_directory,
                device):
    """Segment both phases of every case and save the resulting masks.

    Args:
        model_path: state-dict file for a 1-channel, 4-class bilinear UNet.
        input_data_directory: passed to ``load_path`` to enumerate cases.
        output_data_directory: destination passed to ``save_phase``.
        device: torch device used for the network and inputs.
    """
    data_paths = load_path(input_data_directory)

    net = UNet(n_channels=1, n_classes=4, bilinear=True)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    # Inference: BatchNorm/Dropout must use eval behaviour. predict_img is not
    # visible here, so do not rely on it switching the mode.
    net.eval()

    for path in data_paths:
        ED_np, ES_np = load_phase(path)  # HxWxF
        ED_masks = _segment_frames(net, ED_np, device)
        ES_masks = _segment_frames(net, ES_np, device)
        save_phase(ED_masks, ES_masks, output_data_directory, path)
Exemplo n.º 5
0
import os
import cv2
from torchvision import transforms

from unet_model import UNet

if __name__ == "__main__":
    # glob and torch are used below but are not imported above this block.
    import glob

    import torch

    # Use CUDA when available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Network with a single input channel and one output class.
    net = UNet(n_channels=1, n_classes=1)
    # Move the network to the selected device.
    net.to(device=device)
    # Load the trained weights.
    # net.load_state_dict(torch.load('best_model_100X.pth', map_location=device))
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # Inference mode.
    net.eval()
    # Collect all test image paths.
    tests_path = glob.glob('dataS/test/*.png')
    # tests_path = glob.glob('data100X/test/*.png')
    for test_path in tests_path:
        # Result path: same name with a "_res" suffix. splitext strips only the
        # extension, so dots in directory names are preserved.
        save_res_path = os.path.splitext(test_path)[0] + '_res.png'
        img = cv2.imread(test_path)
        # cv2.imread returns BGR channel order, so convert with BGR2GRAY.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Reshape to (batch=1, channel=1, H, W).
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
Exemplo n.º 6
0
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import tqdm
import cv2
from unet_model import UNet

#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
best_model_name = 'best_model.pt'
# The checkpoint is a dict holding the weights under 'state_dict'.
# map_location remaps GPU-saved tensors so it loads on this CPU-only path.
best_model = torch.load(best_model_name, map_location=device)

model = UNet()
model.load_state_dict(best_model['state_dict'])
model.eval()
model.to(device)

test_dir = '../data/poster/images/'
out_dir = '../data/poster/model/'
test_images = [os.path.join(test_dir, x) for x in os.listdir(test_dir)]

counter = 0
i = 0
for i in range(len(test_images)):
    test_image_one = test_images[i]
    #if 'post' not in test_image_one:
    #    i += 1
    #    continue
    #counter += 1
    #i += 1
Exemplo n.º 7
0
    # Redirect console output to a per-run log file.
    logfile = f'logs/train/{args.dataset_name}_{args.epochs}ep_{args.batch_size}bs.log'
    sys.stdout = Logger(logfile)
    print(args)
    # tensorboard writer
    # writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')

    checkpoint_dir = args.checkpoints_dir + '/' + args.dataset_name
    os.makedirs(checkpoint_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ## --- Set up model
    # Gray images (1 input channel) and a single labelled class.
    net = UNet(n_channels=1, n_classes=1).to(device)
    if args.pretrained_weights:
        # NOTE(review): this expects a raw state dict; checkpoints loaded
        # elsewhere in this project wrap it under 'state_dict' -- confirm.
        net.load_state_dict(
            torch.load(args.pretrained_weights, map_location=device))
        # Report the path that was actually loaded.
        print(f'Pretrained weights loaded from {args.pretrained_weights}')
    print(
        f' Using device {device}\n Network:\n \t{net.n_channels} input channels\n \t{net.n_classes} output channels (classes)\n \t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling'
    )
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    ## --- Set up data
    imgs_dir = 'data/imgs/' + args.dataset_name
    masks_dir = 'data/masks/' + args.dataset_name
    print(f'imgs_dir: {imgs_dir} masks_dir: {masks_dir}')
    dataset = BasicDataset(imgs_dir, masks_dir, args.down_scale)
    # Hold out valid_ratio of the samples for validation.
    n_val = int(len(dataset) * args.valid_ratio)
    n_train = len(dataset) - n_val
    print(f'n_val: {n_val} n_train: {n_train}')
Exemplo n.º 8
0
        test_in_path=
        "/home/star/0_code_lhj/DL-SIM-github/TESTING_DATA/microtuble/HE_X2/",
        transform=ToTensor(),
        img_type='tif',
        in_size=256)
    # Batch the reconstruction dataset for inference.
    # NOTE(review): shuffle=True makes the processing order non-deterministic;
    # harmless only if outputs are keyed by items['image_name'] downstream.
    test_dataloader = torch.utils.data.DataLoader(
        SRRFDATASET, batch_size=batch_size, shuffle=True,
        pin_memory=True)  # better than for loop

    model = UNet(n_channels=3, n_classes=1)

    # Report the number of trainable tensors' elements (parameter count).
    print("{} paramerters in total".format(
        sum(x.numel() for x in model.parameters())))
    # `cuda` is defined outside this fragment -- presumably a device index; confirm.
    model.cuda(cuda)
    model.load_state_dict(
        torch.load(
            "/home/star/0_code_lhj/DL-SIM-github/MODELS/UNet_SIM3_microtubule.pkl"
        ))
    model.eval()

    for batch_idx, items in enumerate(test_dataloader):

        image = items['image_in']
        image_name = items['image_name']
        print(image_name[0])
        # NOTE(review): this switches the model back to training mode (BatchNorm
        # batch statistics, dropout active) for every test batch, overriding
        # model.eval() above -- confirm this is deliberate.
        model.train()

        # The two swaps reorder (N, A, B, C) into (N, C, A, B); presumably the
        # dataset yields channels-last images -- TODO confirm.
        image = np.swapaxes(image, 1, 3)
        image = np.swapaxes(image, 2, 3)
        image = image.float()
        image = image.cuda(cuda)
Exemplo n.º 9
0
    LE_img = imgread(os.path.join(dir_path_LE, "LE_01.tif"))

    # Split the low-exposure image using IMG_SHAPE, then tile every crop.
    LE_512 = cropImage(LE_img, IMG_SHAPE[0], IMG_SHAPE[1])
    # sample_le[n] gathers the n-th tile of every crop, resized to
    # (2*IMG_SIZE, 2*IMG_SIZE) with order=3 interpolation.
    sample_le = {}
    for le_512 in LE_512:
        tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE)
        for n, img in enumerate(tiles):
            img = transform.resize(img, (IMG_SIZE * 2, IMG_SIZE * 2),
                                   preserve_range=True, order=3)
            sample_le.setdefault(n, []).append(img)

    # Indented with spaces only: mixing tabs and spaces is a TabError in Python 3.
    SNR_model = UNet(n_channels=15, n_classes=15)
    print("{} paramerters in total".format(sum(x.numel() for x in SNR_model.parameters())))
    SNR_model.cuda(cuda)
    SNR_model.load_state_dict(torch.load(SNR_model_path))
    # SNR_model.load_state_dict(torch.load(os.path.join(dir_path,"model","LE_HE_mito","LE_HE_0825.pkl")))
    SNR_model.eval()

    SIM_UNET = UNet(n_channels=15, n_classes=1)
    print("{} paramerters in total".format(sum(x.numel() for x in SIM_UNET.parameters())))
    SIM_UNET.cuda(cuda)
    SIM_UNET.load_state_dict(torch.load(SIM_UNET_model_path))
    # SIM_UNET.load_state_dict(torch.load(os.path.join(dir_path,"model","HE_HER_mito","HE_X2_HER_0825.pkl")))
    SIM_UNET.eval()

    SRRFDATASET = ReconsDataset(
    img_dict=sample_le,
    transform=ToTensor(),
    in_norm = LE_in_norm,
    img_type=".tif",
Exemplo n.º 10
0
class UNetObjPrior(nn.Module):
    """
    Wrapper around UNet that takes object priors (gaussians) and images
    as input.

    The image and the prior are concatenated along dim 1 before being passed
    to the wrapped ``UNet(1, self.in_channels, depth, ...)``. Because this class
    defines its own ``train`` (the optimisation loop), that method also accepts
    a single bool so that ``nn.Module.train(mode)`` / ``eval()`` keep working.
    """

    def __init__(self, params, depth=5):
        super(UNetObjPrior, self).__init__()
        self.in_channels = 4
        self.model = UNet(1, self.in_channels, depth, cuda=params['cuda'])
        self.params = params
        self.device = torch.device('cuda' if params['cuda'] else 'cpu')

    def forward(self, im, obj_prior):
        """Concatenate ``im`` and ``obj_prior`` on dim 1 and run the UNet."""
        x = torch.cat((im, obj_prior), dim=1)
        return self.model(x)

    def train(self, dataloader_train=True, dataloader_val=None):
        """Run the train/val loop, or act as nn.Module's mode switch.

        ``nn.Module.eval()`` calls ``self.train(False)``; since this method
        shadows ``nn.Module.train``, a bool first argument is delegated to it.

        Args:
            dataloader_train: training loader (iterable with a ``mode``
                attribute), or a bool training-mode flag.
            dataloader_val: validation loader; must also offer
                ``sample_uniform()`` for the per-epoch preview.

        Returns:
            ``self`` when used as a mode switch, otherwise ``None``.
        """
        if isinstance(dataloader_train, bool):
            return super(UNetObjPrior, self).train(dataloader_train)

        best_loss = float("inf")

        dataloader_train.mode = 'train'
        dataloader_val.mode = 'val'
        dataloaders = {'train': dataloader_train, 'val': dataloader_val}

        optimizer = optim.SGD(self.model.parameters(),
                              momentum=self.params['momentum'],
                              lr=self.params['lr'],
                              weight_decay=self.params['weight_decay'])

        loggers = {
            'train': LossLogger('train', self.params['batch_size'],
                                len(dataloader_train), self.params['out_dir']),
            'val': LossLogger('val', self.params['batch_size'],
                              len(dataloader_val), self.params['out_dir']),
        }

        self.criterion = nn.MSELoss()

        for epoch in range(self.params['num_epochs']):
            print('Epoch {}/{}'.format(epoch, self.params['num_epochs'] - 1))
            print('-' * 10)

            # Each epoch has a training and a validation phase.
            for phase in ['train', 'val']:
                if phase == 'train':
                    self.model.train()
                else:
                    self.model.eval()

                # samp is the 1-based batch index reported to the logger.
                for samp, data in enumerate(dataloaders[phase], start=1):
                    optimizer.zero_grad()

                    # Track history only while training.
                    with torch.set_grad_enabled(phase == 'train'):
                        out = self.forward(data.image, data.obj_prior)
                        loss = self.criterion(out, data.truth)

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    loggers[phase].update(epoch, samp, loss.item())

                loggers[phase].print_epoch(epoch)

                if phase == 'train':
                    self._save_preview(dataloaders['val'], epoch)

                if phase == 'val' and (loggers['val'].get_loss(epoch) <
                                       best_loss):
                    best_loss = loggers['val'].get_loss(epoch)

                loggers[phase].save('log_{}.csv'.format(phase))

                # Save a checkpoint after every validation phase.
                if phase == 'val':
                    is_best = loggers['val'].get_loss(epoch) <= best_loss
                    path = os.path.join(self.params['out_dir'],
                                        'checkpoint.pth.tar')
                    utls.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': self.model.state_dict(),
                            'best_loss': best_loss,
                            'optimizer': optimizer.state_dict()
                        },
                        is_best,
                        path=path)

    def _save_preview(self, dataloader, epoch):
        """Save image / prediction / truth of one sample of ``dataloader`` as a JPEG.

        Runs in eval mode without autograd so the preview neither builds a
        graph nor updates BatchNorm running statistics; the previous mode of
        the wrapped UNet is restored afterwards.
        """
        path = os.path.join(self.params['out_dir'], 'previews',
                            'epoch_{:04d}.jpg'.format(epoch))
        data = dataloader.sample_uniform()
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            pred = self.forward(data.image, data.obj_prior)
        self.model.train(was_training)
        utls.save_tensors(data.image[0], pred[0, ...], data.truth[0], path)

    def load_checkpoint(self, path, device='gpu'):
        """Load ``checkpoint['state_dict']`` from ``path`` into the wrapped UNet.

        Any ``device`` other than ``'gpu'`` keeps every stored tensor on the
        CPU, so GPU-saved checkpoints load on CPU-only hosts.
        """
        if device != 'gpu':
            checkpoint = torch.load(path,
                                    map_location=lambda storage, loc: storage)
        else:
            checkpoint = torch.load(path)

        self.model.load_state_dict(checkpoint['state_dict'])
Exemplo n.º 11
0
                      action='store_true',
                      dest='gpu',
                      default=False,
                      help='use cuda')
    parser.add_option('-c',
                      '--load',
                      dest='load',
                      default=False,
                      help='load file model')

    (options, args) = parser.parse_args()

    net = UNet(3, 1)

    if options.load:
        # Load onto the CPU first: the file may come from a GPU run, and the
        # network is moved to CUDA below only when --gpu is given.
        net.load_state_dict(torch.load(options.load, map_location='cpu'))
        print('Model loaded from {}'.format(options.load))

    if options.gpu:
        net.cuda()
        # Let cuDNN benchmark and cache the fastest convolution algorithms.
        cudnn.benchmark = True

    try:
        train_net(net,
                  options.epochs,
                  options.batchsize,
                  options.lr,
                  gpu=options.gpu)
    except KeyboardInterrupt:
        # Keep the partially trained weights when training is interrupted.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
Exemplo n.º 12
0
    parser.add_argument('--mask_threshold', type=float, default=0.5, help="Minimum probability value to consider a mask pixel white")
    parser.add_argument('--scale', type=float, default=1.0, help="downscale factor for the input images")

    args = parser.parse_args()

    # Checkpoint file name without directory or extension (names the log file).
    # splitext works for any extension length, unlike a fixed [:-4] slice.
    ckpt_str = os.path.splitext(os.path.basename(args.weights_path))[0]
    logfile = f'logs/segment/{args.dataset_name}_{ckpt_str}.log'
    sys.stdout = Logger(logfile)
    print(args)

    output_dir = args.output_dir + '/' + args.dataset_name
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = UNet(n_channels=1, n_classes=1).to(device)
    # The checkpoint stores the weights under 'state_dict'; map_location lets a
    # GPU-saved file load on a CPU-only host.
    checkpoint = torch.load(args.weights_path, map_location=device)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    count = 0
    inference_time_total = 0.0
    val_score_total = 0.0
    image_folder = args.image_folder + '/' + args.dataset_name
    for i, fn in enumerate(os.listdir(image_folder)):
        # Only .jpg files are treated as input images.
        if not fn.endswith('.jpg'):
            continue
        count += 1
        # input image
        img_path = os.path.join(image_folder, fn)
        # The target mask path mirrors the image path with 'imgs' -> 'masks'.
        target_path = img_path.replace('imgs', 'masks')

        # single image prediction
        # single image prediction
def _thin_vessel_scores(pred_imgs, gtruth_masks, se_size):
    """Per-image thin-vessel recall and ROC-AUC for a ``se_size`` square opening.

    The "thick" part of each ground-truth mask is its morphological opening
    with ``square(se_size)``; thin vessels are the mask minus that part.
    Predictions inside the thick part are zeroed on a copy, so ``pred_imgs``
    is not modified. AUC is NaN when an image has a single thin-vessel class.

    Returns:
        (recalls, aucs): two lists with one value per image, rounded to 4 places.
    """
    recalls = []
    aucs = []
    for j in range(pred_imgs.shape[0]):
        gt = gtruth_masks[j, 0, :, :]
        thick = opening(gt, square(se_size))
        thin_gt = gt - thick
        thin_pred = np.where(thick == 1, 0, pred_imgs[j, 0, :, :])
        recalls.append(round(thin_recall(thin_gt, thin_pred, thresh=0.5), 4))
        try:
            aucs.append(round(roc_auc_score(thin_gt.flatten(), thin_pred.flatten()), 4))
        except ValueError:
            # roc_auc_score is undefined when only one class is present.
            aucs.append(float('nan'))
    return recalls, aucs


def test(experiment_path, test_epoch):
    """Evaluate a saved vessel-segmentation checkpoint on the DRIVE test set.

    Settings are read from ``./<experiment_path>/<experiment_path>_config.txt``.
    Test images are cut into patches, predicted with the network stored as
    ``./<experiment name>/DRIVE_<test_epoch>epoch.pth``, stitched back and
    scored inside the field of view. Plots, the stitched predictions (.npy)
    and a text summary appended to ``test_performances_all_epochs.txt`` are
    written to the experiment directory.

    Args:
        experiment_path: experiment directory name (also the config prefix).
        test_epoch: epoch number of the checkpoint to evaluate.
    """
    # ========= CONFIG FILE TO READ FROM =======
    config = configparser.RawConfigParser()
    config.read('./' + experiment_path + '/' + experiment_path + '_config.txt')
    path_data = config.get('data paths', 'path_local')
    model = config.get('training settings', 'model')
    # original test images (for FOV selection)
    DRIVE_test_imgs_original = path_data + config.get('data paths', 'test_imgs_original')
    test_imgs_orig = load_hdf5(DRIVE_test_imgs_original)
    full_img_height = test_imgs_orig.shape[2]
    full_img_width = test_imgs_orig.shape[3]
    # the border masks provided by the DRIVE
    DRIVE_test_border_masks = path_data + config.get('data paths', 'test_border_masks')
    test_border_masks = load_hdf5(DRIVE_test_border_masks)
    DRIVE_test_groundTruth = path_data + config.get('data paths', 'test_groundTruth')
    # dimension of the patches
    patch_height = int(config.get('data attributes', 'patch_height'))
    patch_width = int(config.get('data attributes', 'patch_width'))
    # the stride in case output with average
    stride_height = int(config.get('testing settings', 'stride_height'))
    stride_width = int(config.get('testing settings', 'stride_width'))
    assert (stride_height < patch_height and stride_width < patch_width)
    name_experiment = config.get('experiment name', 'name')
    path_experiment = './' + name_experiment + '/'
    # N full images to be predicted
    Imgs_to_test = int(config.get('testing settings', 'full_images_to_test'))
    # Grouping of the predicted images
    N_visual = int(config.get('testing settings', 'N_group_visual'))
    average_mode = config.getboolean('testing settings', 'average_mode')

    # ============ Load the data and divide in patches
    new_height = new_width = masks_test = patches_masks_test = None
    if average_mode:
        patches_imgs_test, new_height, new_width, masks_test = get_data_testing_overlap(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width,
            stride_height=stride_height,
            stride_width=stride_width)
    else:
        patches_imgs_test, patches_masks_test = get_data_testing_test(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width)

    # ================ Run the prediction of the patches ==================================
    if model == 'UNet':
        net = UNet(n_channels=1, n_classes=2)
    elif model == 'UNet_cat':
        net = UNet_cat(n_channels=1, n_classes=2)
    else:
        net = UNet_level4_our(n_channels=1, n_classes=2)
    # Dummy zero targets: only the images are needed for prediction.
    test_data = data.TensorDataset(torch.tensor(patches_imgs_test),
                                   torch.zeros(patches_imgs_test.shape[0]))
    test_loader = data.DataLoader(test_data, batch_size=1, pin_memory=True, shuffle=False)
    trained_model = path_experiment + 'DRIVE_' + str(test_epoch) + 'epoch.pth'
    print(trained_model)
    net.load_state_dict(torch.load(trained_model))
    net.eval()
    print('Finished loading model :' + trained_model)
    net = net.cuda()
    cudnn.benchmark = True

    # Per-patch class probabilities, flattened to (n_patches, H*W, 2).
    predictions_out = np.empty((patches_imgs_test.shape[0], patch_height * patch_width, 2))
    offset = 0
    with torch.no_grad():
        for images, _ in test_loader:
            out1 = net(images.float().cuda())
            # (N, C, H, W) -> (N, H, W, C) so softmax runs over the class axis.
            pred = F.softmax(out1.permute(0, 2, 3, 1), dim=-1)
            # Bring the result back to host memory before writing into numpy.
            pred = pred.reshape(-1, patch_height * patch_width, 2).cpu().numpy()
            predictions_out[offset:offset + pred.shape[0]] = pred
            offset += pred.shape[0]

    # ===== Convert the prediction arrays in corresponding images
    pred_patches_out = pred_to_imgs(predictions_out, patch_height, patch_width, "original")

    #========== Elaborate and visualize the predicted images ====================
    if average_mode:
        pred_imgs_out = recompone_overlap(pred_patches_out, new_height, new_width,
                                          stride_height, stride_width)
        orig_imgs = my_PreProc(test_imgs_orig[0:pred_imgs_out.shape[0], :, :, :])  # originals
        gtruth_masks = masks_test  # ground truth masks
    else:
        # NOTE(review): the 10 x 9 patch grid is hard-coded for this setup.
        pred_imgs_out = recompone(pred_patches_out, 10, 9)  # predictions
        orig_imgs = recompone(patches_imgs_test, 10, 9)  # originals
        gtruth_masks = recompone(patches_masks_test, 10, 9)  # masks

    # Zero predictions outside the DRIVE field of view; kill_border's return
    # value is unused, so it is relied upon to modify pred_imgs_out in place.
    kill_border(pred_imgs_out, test_border_masks)
    # back to original dimensions
    orig_imgs = orig_imgs[:, :, 0:full_img_height, 0:full_img_width]
    pred_imgs_out = pred_imgs_out[:, :, 0:full_img_height, 0:full_img_width]
    gtruth_masks = gtruth_masks[:, :, 0:full_img_height, 0:full_img_width]

    print("Orig imgs shape: " + str(orig_imgs.shape))
    print("pred imgs shape: " + str(pred_imgs_out.shape))
    print("Gtruth imgs shape: " + str(gtruth_masks.shape))
    np.save(path_experiment + 'pred_img_' + str(test_epoch) + "_epoch" + '.npy', pred_imgs_out)
    if average_mode:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "thresh_epoch")
    else:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "epoch_no_average")
    visualize(group_images(gtruth_masks, N_visual), path_experiment + "all_groundTruths")

    # ====== Evaluate the results
    print("\n\n========  Evaluate the results =======================")
    # predictions only inside the FOV
    y_scores, y_true = pred_only_FOV(pred_imgs_out, gtruth_masks, test_border_masks)

    # Area under the ROC curve
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    AUC_ROC = roc_auc_score(y_true, y_scores)
    print("\nArea under the ROC curve: " + str(AUC_ROC))
    roc_fig = plt.figure()
    plt.plot(fpr, tpr, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_ROC)
    plt.title('ROC curve')
    plt.xlabel("FPR (False Positive Rate)")
    plt.ylabel("TPR (True Positive Rate)")
    plt.legend(loc="lower right")
    plt.savefig(path_experiment + "ROC.png")
    plt.close(roc_fig)  # release the figure; test() may run once per epoch

    # Precision-recall curve
    precision, recall, _ = precision_recall_curve(y_true, y_scores)
    precision = np.fliplr([precision])[0]  # so the array is increasing (you won't get negative AUC)
    recall = np.fliplr([recall])[0]  # so the array is increasing (you won't get negative AUC)
    AUC_prec_rec = np.trapz(precision, recall)
    print("\nArea under Precision-Recall curve: " + str(AUC_prec_rec))
    pr_fig = plt.figure()
    plt.plot(recall, precision, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_prec_rec)
    plt.title('Precision - Recall curve')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower right")
    plt.savefig(path_experiment + "Precision_recall.png")
    plt.close(pr_fig)

    # Confusion matrix
    threshold_confusion = 0.5
    print("\nConfusion matrix:  Custom threshold (for positive) of " + str(threshold_confusion))
    y_pred = np.where(np.ravel(y_scores) >= threshold_confusion, 1.0, 0.0)
    confusion = confusion_matrix(y_true, y_pred)
    print(confusion)
    accuracy = 0
    if float(np.sum(confusion)) != 0:
        accuracy = float(confusion[0, 0] + confusion[1, 1]) / float(np.sum(confusion))
    print("Global Accuracy: " + str(accuracy))
    specificity = 0
    if float(confusion[0, 0] + confusion[0, 1]) != 0:
        specificity = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])
    print("Specificity: " + str(specificity))
    sensitivity = 0
    if float(confusion[1, 1] + confusion[1, 0]) != 0:
        sensitivity = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])
    print("Sensitivity: " + str(sensitivity))
    precision = 0
    if float(confusion[1, 1] + confusion[0, 1]) != 0:
        precision = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[0, 1])
    print("Precision: " + str(precision))

    # Jaccard similarity index
    # NOTE(review): jaccard_similarity_score was removed in scikit-learn 0.23;
    # its replacement jaccard_score has different semantics for binary input.
    jaccard_index = jaccard_similarity_score(y_true, y_pred, normalize=True)
    print("\nJaccard similarity score: " + str(jaccard_index))

    # F1 score
    F1_score = f1_score(y_true, y_pred, labels=None, average='binary', sample_weight=None)
    print("\nF1 score (F-measure): " + str(F1_score))

    # Thin-vessel metrics; each analysis works on its own copy of the predictions.
    thin_3pixel_recall_indivi, thin_3pixel_auc_roc = _thin_vessel_scores(
        pred_imgs_out, gtruth_masks, 3)
    thin_2pixel_recall_indivi, thin_2pixel_auc_roc = _thin_vessel_scores(
        pred_imgs_out, gtruth_masks, 2)

    # Save the results
    with open(path_experiment + 'test_performances_all_epochs.txt', mode='a') as f:
        f.write("\n\n" + path_experiment + " test epoch:" + str(test_epoch)
                + '\naverage mode is:' + str(average_mode)
                + "\nArea under the ROC curve: %.4f" % (AUC_ROC)
                + "\nArea under Precision-Recall curve: %.4f" % (AUC_prec_rec)
                + "\nJaccard similarity score: %.4f" % (jaccard_index)
                + "\nF1 score (F-measure): %.4f" % (F1_score)
                + "\nConfusion matrix:"
                + str(confusion)
                + "\nACCURACY: %.4f" % (accuracy)
                + "\nSENSITIVITY: %.4f" % (sensitivity)
                + "\nSPECIFICITY: %.4f" % (specificity)
                + "\nPRECISION: %.4f" % (precision)
                + "\nthin 2vessels recall indivi:\n" + str(thin_2pixel_recall_indivi)
                + "\nthin 2vessels recall mean:%.4f" % (np.mean(thin_2pixel_recall_indivi))
                + "\nthin 2vessels auc indivi:\n" + str(thin_2pixel_auc_roc)
                + "\nthin 2vessels auc score mean:%.4f" % (np.mean(thin_2pixel_auc_roc))
                + "\nthin 3vessels recall indivi:\n" + str(thin_3pixel_recall_indivi)
                + "\nthin 3vessels recall mean:%.4f" % (np.mean(thin_3pixel_recall_indivi))
                + "\nthin 3vessels auc indivi:\n" + str(thin_3pixel_auc_roc)
                + "\nthin 3vessels auc score mean:%.4f" % (np.mean(thin_3pixel_auc_roc))
                )
Exemplo n.º 14
0
    # sample_le[n] gathers the n-th tile of every crop in LE_512, each resized
    # to (2*IMG_SIZE, 2*IMG_SIZE) with order=3 interpolation.
    sample_le = {}
    for le_512 in LE_512:
        tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE)
        for n, img in enumerate(tiles):
            img = transform.resize(img, (IMG_SIZE * 2, IMG_SIZE * 2),
                                   preserve_range=True,
                                   order=3)
            sample_le.setdefault(n, []).append(img)

    SC_UNET = UNet(n_channels=15, n_classes=1)
    print("{} paramerters in total".format(
        sum(x.numel() for x in SC_UNET.parameters())))
    SC_UNET.cuda(cuda)
    SC_UNET.load_state_dict(torch.load(model_path))
    # SC_UNET.load_state_dict(torch.load(os.path.join(dir_path,"model","HE_HER_mito","HE_X2_HER_0825.pkl")))
    SC_UNET.eval()

    # Feed the low-exposure tiles built above, matching in_norm=LE_in_norm
    # (the name sample_he is never built in this code).
    SRRFDATASET = ReconsDataset(img_dict=sample_le,
                                transform=ToTensor(),
                                in_norm=LE_in_norm,
                                img_type=".tif",
                                in_size=256)
    test_dataloader = torch.utils.data.DataLoader(
        SRRFDATASET, batch_size=1, shuffle=False,
        pin_memory=True)  # better than for loop
    # One 256x256 output slice per dataset sample.
    result = np.zeros((256, 256, len(SRRFDATASET)))
    for batch_idx, items in enumerate(test_dataloader):
        image = items['image_in']
        image_idx = items['image_name']
Exemplo n.º 15
0
# Output directory for this experiment's test results (created if missing).
test_dir = dir_names.test_dir + '/' + experiment_name + '/'

if not os.path.exists(test_dir):
    os.makedirs(test_dir)

# NOTE(review): c.force_create is not shown here -- presumably it (re)creates
# the directory; confirm whether it clears existing contents.
if not load_model:
    c.force_create(model_dir)
    c.force_create(tfboard_dir)

# Define image dataset (reads in full images and segmentations).
test_dataset = p.ImageDataset_withPrior(csv_file=c.final_test_csv)

num_class = 3
# Evaluate a specific saved checkpoint.
model_file = model_dir + 'model_17.pth'
# in_channels=2: presumably image + prior -- TODO confirm against UNet's signature.
net = UNet(num_class, in_channels=2)
net.load_state_dict(torch.load(model_file, map_location=device))
net.eval()
net.to(device)

# Presumably the border padding used for patch extraction -- confirm in use site.
pad_size = c.half_patch[0]
include_prior = True

with torch.no_grad():
    for i in range(6, len(test_dataset)):

        sample = test_dataset[i]
        if include_prior:
            prior = sample['prior']
            test_patches = p.GeneratePatches(sample,
                                             is_training=False,
                                             transform=False,
Exemplo n.º 16
0
import os
from PIL import Image
from predict import *
from utils import encode
from unet_model import UNet


def submit(net, gpu=False, test_dir='data/test/', out_path='SUBMISSION.csv'):
    """Predict a mask for every image in ``test_dir`` and write an RLE CSV.

    Args:
        net: Segmentation network, forwarded to ``predict_img``.
        gpu: Forwarded to ``predict_img``.
        test_dir: Directory holding the test images (defaults to the
            original hard-coded ``'data/test/'``).
        out_path: Destination CSV. It is overwritten, so the
            ``img,rle_mask`` header appears exactly once.
    """
    # List the directory once so the progress total and the loop agree.
    filenames = os.listdir(test_dir)
    total = len(filenames)

    # 'w' rather than the original 'a': appending re-wrote the header row on
    # every run, producing an invalid CSV after the first submission.
    with open(out_path, 'w') as f:
        f.write('img,rle_mask\n')
        for index, name in enumerate(filenames):
            print('{}/{}'.format(index, total))

            # Context manager releases the file handle once the mask is computed.
            with Image.open(os.path.join(test_dir, name)) as img:
                mask = predict_img(net, img, gpu)
            # NOTE(review): this module imports ``encode`` from utils but calls
            # ``rle_encode``; presumably it comes from ``from predict import *``.
            enc = rle_encode(mask)
            f.write('{},{}\n'.format(name, ' '.join(map(str, enc))))


if __name__ == '__main__':
    # Entry point: build UNet(3, 1) on the GPU, restore the weights stored in
    # MODEL.pth, then write the submission file with GPU inference enabled.
    net = UNet(3, 1)
    net = net.cuda()
    weights = torch.load('MODEL.pth')
    net.load_state_dict(weights)
    submit(net, gpu=True)
Exemplo n.º 17
0
            loss = criterion(y_pred, y.unsqueeze(0).float())

            l += loss.data[0]
            loss.backward()
            if i % 10 == 0:
                optimizer.step()
                print('Stepped')

            print('{0:.4f}%\t\t{1:.6f}'.format(i / len(ids) * 100,
                                               loss.data[0]))

        l = l / len(ids)
        print('Loss : {}'.format(l))
        torch.save(net.state_dict(),
                   'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch + 1, l))
        print('Saved')


# Resume from the checkpoint written by a previous Ctrl-C (if one exists),
# train, and on Ctrl-C save the current weights so the next run can resume.
if os.path.exists('MODEL_INTERRUPTED.pth'):
    # The original loaded unconditionally, raising FileNotFoundError on a
    # fresh run. Loading outside the try also avoids saving a half-restored
    # model if Ctrl-C arrives during the load.
    net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth'))

try:
    train(net)
except KeyboardInterrupt:
    print('Interrupted')
    torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
    try:
        sys.exit(0)
    except SystemExit:
        # sys.exit raises SystemExit, caught right here, so this always ends
        # in os._exit(0): an immediate exit without interpreter cleanup.
        os._exit(0)
Exemplo n.º 18
0
        if convert_to_2d:
            inp_set = get_2d_converted_data(inp_set)
        inp_set = torch.from_numpy(inp_set).float()
        file_name = os.path.basename(inp_file)
        out_file = os.path.join(out_dir, file_name)
        data.append((inp_set, out_file))
    return data

def save_pred(model, data, device=None):
    """Run ``model`` on each input and save the prediction to a .mat file.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        data: Iterable of ``(image_tensor, file_path)`` pairs; each
            prediction is written to its ``file_path``.
        device: Device used for inference. Defaults to the module-level
            ``cuda`` device, as in the original code.

    The first batch element of each prediction is clipped to [0, 1], scaled
    to 0-255, cast to uint8, transposed with axes (1, 2, 0) (presumably
    C,H,W -> H,W,C) and stored under the key ``'crop_g'``.
    """
    if device is None:
        device = cuda  # module-level torch.device defined by the script
    model.eval()
    # No autograd graph is needed for inference; avoids wasted memory.
    with torch.no_grad():
        for image, file_path in data:
            pred = model(image.to(device))
            pred = pred.detach().cpu().numpy()[0]
            # Clip before scaling: values outside [0, 1] would otherwise
            # overflow/wrap when cast to uint8.
            pred = (np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8)
            pred = pred.transpose((1, 2, 0))
            savemat(file_path, {'crop_g': pred})

if __name__ == '__main__':
    cuda = torch.device('cuda')
    model = UNet(n_channels=45, n_classes=3)
    print("{} Parameters in total".format(sum(x.numel() for x in model.parameters())))
    # Restore the final checkpoint straight onto the target device. The
    # original moved the model to the GPU twice; once is enough.
    model.load_state_dict(torch.load(model_loc + "Model_Final_999_3_5.pkl",
                                     map_location=cuda))
    model.cuda(cuda)
    model.eval()
    data = get_images()
    save_pred(model, data)
Exemplo n.º 19
0
    img = transform(img)
    img = img.unsqueeze(0)


    def get_layer_param(model):
        return sum([torch.numel(param) for param in model.parameters()])


    net = UNet(1, 3).to(device)
    print(net)
    print('parameters:', get_layer_param(net))

    print("Loading checkpoint...")
    checkpoint = torch.load(ckpt_path)
    net.load_state_dict(checkpoint['net_state_dict'])
    net.eval()

    print("Starting Test...")
    # -----------------------------------------------------------
    # Initial batch
    data_A = img.to(device)
    # -----------------------------------------------------------
    # Generate fake img:
    fake_B = net(data_A)
    # -----------------------------------------------------------
    # Output training stats
    # vutils.save_image(data_A, os.path.join(samples_path, 'result', '%s_data_A.jpg' % str(i).zfill(6)),
    #                   padding=2, nrow=2, normalize=True)
    vutils.save_image(fake_B, os.path.join('./', '%s_fake_B_leaky.jpg' % filename[0:-4]),
                      padding=0, nrow=1, normalize=True)
Exemplo n.º 20
0
def train_second_layer(model, criterion, optimizer, scheduler, num_epochs=25):
    """Build two fresh UNets initialised from the same ``best_model_wts``.

    NOTE(review): ``best_model_wts`` is neither a parameter nor a local here,
    so it must be a module-level global defined elsewhere -- confirm.
    NOTE(review): ``model``, ``criterion``, ``optimizer``, ``scheduler`` and
    ``num_epochs`` are unused in the visible body and nothing is returned;
    the body looks truncated -- confirm against the original source.
    """
    model1 = UNet()
    model2 = UNet()
    model1.load_state_dict(best_model_wts)
    model2.load_state_dict(best_model_wts)
Exemplo n.º 21
0
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid, save_image
from PIL import Image
import os
import numpy as np
from PIL import Image
from unet_model import UNet
import random

# This script uses ``torch`` and ``transforms`` but the import lines above
# never bring them into scope (NameError at runtime); import them here.
import torch
from torchvision import transforms

input_dir = "../Test_Data/"
output_dir = "../Generated_Test/"
model_path = "models/model_gen_latest"

# Inference runs on CPU: map_location lets a GPU-saved checkpoint load
# without CUDA.
generator = UNet(n_channels=3, n_classes=2)
generator.load_state_dict(torch.load(model_path, map_location="cpu"))
generator.eval()

# save_image does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

preprocess = transforms.Compose(
    [transforms.Resize((75, 210)),
     transforms.ToTensor()])

# Deterministic order; the original random.sample shuffle had no effect on
# the generated files.
with torch.no_grad():  # no autograd graph needed for inference
    for filename in sorted(os.listdir(input_dir)):
        # convert("RGB") guarantees the 3 input channels the generator is
        # built for (grayscale/RGBA/palette inputs would otherwise fail).
        with Image.open(os.path.join(input_dir, filename)) as img:
            batch = preprocess(img.convert("RGB")).unsqueeze(0)
        output_img = generator(batch)
        # NOTE(review): n_classes=2 presumably means a 2-channel output;
        # confirm save_image handles that for the target file format.
        save_image(output_img, os.path.join(output_dir, filename))
Exemplo n.º 22
0
    pin_memory=False)  # better than for loop
test_dataloader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=True,
    pin_memory=False)  # better than for loop

# Choose the architecture. The original condition was
# ``not (convert_to_2d, data_reduced)``: ``not`` of a non-empty tuple is
# always False, so UNet3D was unreachable. With the condition fixed, the
# second test must be ``elif`` so the fallback ``else`` cannot overwrite
# the UNet3D choice.
if is_3d and not (convert_to_2d or data_reduced):  # 3D Processing directly
    model = UNet3D(n_channels=X_train.shape[1], n_classes=y_train.shape[1])
elif data_reduced and single_unet:  # 3D Converted to 2D with data reduction of reduced_size
    model = UNet3SIM(n_channels=X_train.shape[1], n_classes=y_train.shape[1])
else:
    model = UNet(n_channels=X_train.shape[1], n_classes=y_train.shape[1])

# Optionally resume weights and epoch counter from checkpoint "650.pt".
start_epoch = 0
if load_model:
    weight = torch.load(model_loc + "650.pt")
    model.load_state_dict(weight['model_state_dict'])
    start_epoch = weight['epoch']

print("{} Parameters in Total".format(
    sum(x.numel() for x in model.parameters())))
print(" See model input shape first:", X_train.shape[1], y_train.shape[1])
if have_cuda:
    model.cuda(cuda)
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             betas=(0.9, 0.999))
train_loss = []
learn_rate = []
c = 0
val_metrics = {
Exemplo n.º 23
0
class Train(object):
    """Train a UNet segmentation model with Dice loss, keeping the best weights.

    ``configs`` is a mapping of options. Missing keys fall back to defaults
    (batch_size=16, epochs=100, lr=1e-4, workers=4, vis_images=200,
    vis_freq=10, weights="./weights", logs="./logs", images_path="./data").
    Numeric options are coerced with ``int``/``float``, so string values
    (e.g. read from a text config) also work.
    """

    def __init__(self, configs):
        pre_train, model_path, is_multi_gpu = self._parse_configs(configs)

        self.step = 0  # global iteration counter used for periodic printing

        self.dsc_loss = DiceLoss()
        self.model = UNet(in_channels=Dataset.in_channels,
                          out_channels=Dataset.out_channels)
        if pre_train:
            # strict=False tolerates missing/unexpected keys in the checkpoint.
            self.model.load_state_dict(torch.load(model_path,
                                                  map_location=self.device),
                                       strict=False)

        if is_multi_gpu:
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        self.best_validation_dsc = 0.0

        self.loader_train, self.loader_valid = self.data_loaders()

        self.params = [p for p in self.model.parameters() if p.requires_grad]

        self.optimizer = optim.Adam(self.params,
                                    lr=self.lr,
                                    weight_decay=0.0005)
        # 'poly' schedule; invoked once per iteration from train_one_epoch.
        self.scheduler = lr_scheduler.LR_Scheduler_Head(
            'poly', self.lr, self.epochs, len(self.loader_train))

    def _parse_configs(self, configs):
        """Read options from ``configs`` into attributes and create output dirs.

        Returns:
            ``(pre_train, model_path, is_multi_gpu)`` used to build the model.
        """
        # The original defaults were strings ("16", "100", "0.0001", ...),
        # which break DataLoader, range() and Adam; coerce to numbers.
        self.batch_size = int(configs.get("batch_size", 16))
        self.epochs = int(configs.get("epochs", 100))
        self.lr = float(configs.get("lr", 0.0001))

        device_args = configs.get("device", "cuda")
        self.device = torch.device(
            "cpu" if not torch.cuda.is_available() else device_args)

        self.workers = int(configs.get("workers", 4))

        self.vis_images = int(configs.get("vis_images", 200))
        self.vis_freq = int(configs.get("vis_freq", 10))

        self.weights = configs.get("weights", "./weights")
        os.makedirs(self.weights, exist_ok=True)

        # The original re-checked self.weights here (copy-paste slip), so the
        # logs directory was never created.
        self.logs = configs.get("logs", "./logs")
        os.makedirs(self.logs, exist_ok=True)

        self.images_path = configs.get("images_path", "./data")

        # These reads used the undefined name ``config`` in the original
        # (presumably a typo for ``configs``).
        self.is_resize = configs.get("is_resize", False)
        self.image_short_side = configs.get("image_short_side", 256)
        self.is_padding = configs.get("is_padding", False)

        # "DateParallel" (sic) is the key the original read; the correctly
        # spelled "DataParallel" is accepted too and takes precedence.
        is_multi_gpu = configs.get("DataParallel",
                                   configs.get("DateParallel", False))

        pre_train = configs.get("pre_train", False)
        model_path = configs.get("model_path", './weights/unet_idcard_adam.pth')
        return pre_train, model_path, is_multi_gpu

    def datasets(self):
        """Build the training and validation ``Dataset`` objects."""
        train_datasets = Dataset(
            images_dir=self.images_path,
            subset="train",  # train
            transform=get_transforms(train=True),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=self.is_padding)

        valid_datasets = Dataset(
            images_dir=self.images_path,
            subset="validation",  # validation
            transform=get_transforms(train=False),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=False)
        return train_datasets, valid_datasets

    def data_loaders(self):
        """Return ``(loader_train, loader_valid)``; validation uses batch size 1."""
        dataset_train, dataset_valid = self.datasets()

        loader_train = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
        )
        loader_valid = DataLoader(
            dataset_valid,
            batch_size=1,
            drop_last=False,
            num_workers=self.workers,
        )

        return loader_train, loader_valid

    @staticmethod
    def dsc_per_volume(validation_pred, validation_true):
        """Return one ``dsc`` score per (prediction, target) pair."""
        assert len(validation_pred) == len(validation_true)
        return [dsc(np.array([pred]), np.array([true]))
                for pred, true in zip(validation_pred, validation_true)]

    @staticmethod
    def get_logger(filename, verbosity=1, name=None):
        """Return a logger writing to ``filename`` (truncated) and to stderr.

        ``verbosity`` maps 0/1/2 to DEBUG/INFO/WARNING.
        """
        level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
        formatter = logging.Formatter(
            "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
        )
        logger = logging.getLogger(name)
        logger.setLevel(level_dict[verbosity])

        fh = logging.FileHandler(filename, "w")
        fh.setFormatter(formatter)
        logger.addHandler(fh)

        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)

        return logger

    def train_one_epoch(self, epoch):
        """Run one training epoch over ``self.loader_train``."""
        self.model.train()
        loss_train = []
        for i, (x, y_true) in enumerate(self.loader_train):
            # Per-iteration scheduler call (presumably sets the optimizer LR).
            self.scheduler(self.optimizer, i, epoch, self.best_validation_dsc)
            x, y_true = x.to(self.device), y_true.to(self.device)

            y_pred = self.model(x)
            loss = self.dsc_loss(y_pred, y_true)

            loss_train.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.step % 200 == 0:
                print('Epoch:[{}/{}]\t iter:[{}]\t loss={:.5f}\t '.format(
                    epoch, self.epochs, i, loss.item()))

            self.step += 1

    def eval_model(self, patience):
        """Evaluate on the validation set; save weights when mean DSC improves.

        ``patience`` is accepted for API compatibility but currently unused.
        """
        self.model.eval()
        loss_valid = []

        validation_pred = []
        validation_true = []

        for x, y_true in self.loader_valid:
            x, y_true = x.to(self.device), y_true.to(self.device)

            with torch.no_grad():
                y_pred = self.model(x)
                loss = self.dsc_loss(y_pred, y_true)

            # Debug dump of the current sample's thresholded mask. (bool * 255)
            # yields int64, which OpenCV cannot write; cast to uint8 first.
            mask = ((y_pred > 0.5) * 255).cpu().numpy()[0][0].astype(np.uint8)
            cv2.imwrite('result.png', mask)

            loss_valid.append(loss.item())

            # Iterating an ndarray yields its first-axis (per-sample) slices.
            validation_pred.extend(y_pred.detach().cpu().numpy())
            validation_true.extend(y_true.detach().cpu().numpy())

        mean_dsc = np.mean(
            self.dsc_per_volume(
                validation_pred,
                validation_true,
            ))
        if mean_dsc > self.best_validation_dsc:
            self.best_validation_dsc = mean_dsc
            torch.save(self.model.state_dict(),
                       os.path.join(self.weights, "unet_xia_adam.pth"))
            print("Best validation mean DSC: {:4f}".format(
                self.best_validation_dsc))

    def main(self):
        """Train for ``self.epochs`` epochs, validating after each, then save."""
        for epoch in tqdm(range(self.epochs), total=self.epochs):
            self.train_one_epoch(epoch)
            self.eval_model(patience=10)

        torch.save(self.model.state_dict(),
                   os.path.join(self.weights, "unet_final.pth"))
Exemplo n.º 24
0
class OnePredict(object):
    """Single-image UNet inference that writes mask/contour debug images.

    ``params['model_path']`` is the checkpoint loaded (``strict=False``) into
    a ``UNet(in_channels=3, out_channels=1)``.
    """

    def __init__(self, params):
        self.params = params

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model_path = params['model_path']

        self.model = UNet(in_channels=3, out_channels=1)

        # Threshold applied to predictions in post_process.
        self.threshold = 0.5

        self.resume()

        self.transform = get_transforms_3()

        self.is_resize = True
        self.image_short_side = 1024
        self.init_torch_tensor()
        self.model.eval()

    def init_torch_tensor(self):
        """Select the device and matching default tensor type (CUDA if available)."""
        torch.set_default_tensor_type('torch.FloatTensor')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.device = torch.device('cpu')

    def resume(self):
        """Load the checkpoint (non-strict) and move the model to ``self.device``."""
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device), strict=False)
        self.model.to(self.device)

    def _target_size(self, width, height):
        """Return ``(new_width, new_height)`` with both sides multiples of 32.

        With ``is_resize`` the short side becomes ``image_short_side``;
        otherwise it is rounded down to a multiple of 32 (at least 32). The
        long side is scaled proportionally and rounded up to a multiple of 32.
        """
        if height < width:
            if self.is_resize:
                new_height = self.image_short_side
            else:
                # max(..., 1): images shorter than 32 px used to give 0x0.
                new_height = max(int(height / 32), 1) * 32
            new_width = int(math.ceil(new_height / height * width / 32) * 32)
        else:
            if self.is_resize:
                new_width = self.image_short_side
            else:
                new_width = max(int(width / 32), 1) * 32
            new_height = int(math.ceil(new_width / width * height / 32) * 32)
        return new_width, new_height

    def resize_img(self, img):
        """Resize a PIL image so both sides are multiples of 32 (see _target_size)."""
        new_width, new_height = self._target_size(*img.size)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        return img.resize((new_width, new_height), Image.LANCZOS)

    def format_output(self):
        pass

    @staticmethod
    def pre_process(img):
        return img

    @staticmethod
    def pad_sample(img):
        """Pad a PIL image with black borders to make it square (centred)."""
        a = img.size[0]
        b = img.size[1]
        if a == b:
            return img
        diff = (max(a, b) - min(a, b)) / 2.0
        if a > b:
            padding = (0, int(np.floor(diff)), 0, int(np.ceil(diff)))
        else:
            padding = (int(np.floor(diff)), 0, int(np.ceil(diff)), 0)

        img = ImageOps.expand(img, border=padding, fill=0)  ##left,top,right,bottom

        assert img.size[0] == img.size[1]
        return img

    def post_process(self, preds, img):
        """Threshold ``preds``, save 'mask.png', draw contours onto 'result2.png'.

        Returns an (currently always empty) list of boxes.
        """
        mask = preds > self.threshold
        # (bool * 255) is int64, which OpenCV does not support; cast to uint8
        # before writing/contouring (the original wrote the int64 array).
        mask = (mask * 255).cpu().numpy()[0][0].astype(np.uint8)
        cv2.imwrite('mask.png', mask)

        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # ``img`` is an RGB PIL image; cv2 draws and writes in BGR order.
        img = cv2.cvtColor(np.array(img, np.uint8), cv2.COLOR_RGB2BGR)
        cv2.drawContours(img, contours, -1, (0, 0, 255), 1)
        cv2.imwrite('result2.png', img)

        boxes = []
        return boxes

    @staticmethod
    def demo_visualize():
        pass

    def inference(self, img_path, is_visualize=True, is_format_output=False):
        """Run the model on the image at ``img_path`` and return post_process boxes."""
        # cv2.imread's second argument is a read flag: the original passed
        # cv2.COLOR_BGR2RGB (numerically IMREAD_ANYCOLOR), so no conversion
        # happened and the BGR data was treated as RGB (channels swapped).
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError('cannot read image: {}'.format(img_path))
        img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)).convert("RGB")
        img = self.resize_img(img)
        ori_img = img
        img.save('img.png')
        print('111', np.array(img))
        img = self.transform(img)
        print('222', np.array(img))
        img = img.unsqueeze(0)
        img = img.to(self.device)

        with torch.no_grad():
            s1 = time.time()
            preds = self.model(img)
            print(preds)
            s2 = time.time()
            print(s2 - s1)
            boxes = self.post_process(preds, ori_img)
        return boxes
Exemplo n.º 25
0
	os.makedirs(save_dir)
if not os.path.exists(model_path):
    os.makedirs(model_path)

# NOTE(review): another script loading 'model_gen_latest' builds
# UNet(n_channels=3, n_classes=2); confirm UNet accepts ``out_channels``.
generator = UNet(n_channels=3, out_channels=1)
discriminator_g = GlobalDiscriminator()
discriminator_l = LocalDiscriminator()

# Pass "resume" as the first command-line argument to continue training.
resume = len(sys.argv) > 1 and sys.argv[1] == 'resume'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model if available
if resume:
    print('Resuming training....')
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    generator.load_state_dict(torch.load(
        os.path.join(model_path, 'model_gen_latest'), map_location=device))
    discriminator_g.load_state_dict(torch.load(
        os.path.join(model_path, 'model_gdis_latest'), map_location=device))
    discriminator_l.load_state_dict(torch.load(
        os.path.join(model_path, 'model_ldis_latest'), map_location=device))

generator = generator.to(device)
discriminator_g = discriminator_g.to(device)
discriminator_l = discriminator_l.to(device)

# optimizer_g / optimizer_l update the global / local discriminators.
optimizer_g = optim.Adam(discriminator_g.parameters(), lr=0.00005)
optimizer_l = optim.Adam(discriminator_l.parameters(), lr=0.00005)
gen_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

lossdis = nn.BCELoss()
lossgen = FocalLoss()
lamda = 75
Exemplo n.º 26
0
    smooth = 1.
    num = pred.size(0)
    m1 = pred.view(num, -1)  # Flatten
    m2 = target.view(num, -1)  # Flatten
    intersection = (m1 * m2).sum()
    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道,分类为1。
    net = UNet(n_channels=3, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 加载模型参数
    net.load_state_dict(torch.load('./save_model/CP_epoch300.pth', map_location=device))
    # 测试模式
    net.eval()
    # 读取所有图片路径
    data_path = "./data/CHASE/test"
    tests_path = glob.glob(os.path.join(data_path, r'image/*.jpg'))
    name_dataset = DealDataset(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=name_dataset, batch_size=1,shuffle=False)  # shuffle 填True 就会打乱
    # print(train_loader.len)
    # save_res_path = test_path.split('.')[0] + '_res2021.png'
    # 遍历所有图片
    total_batch = int(len(name_dataset)/1)
    bar = tqdm(enumerate(train_loader),total=total_batch)
    for batch_index, batch in bar:
        image = batch['image']
        # label = batch['label'].squeeze(0)