Example #1
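# Inference job for Cytomine: load a trained U-Net, run it over a whole-slide image
# with a tile-based segmentation workflow, and upload the detected polygons as annotations.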
def main(argv):
    with CytomineJob.from_cli(argv) as job:
        model_path = os.path.join(str(Path.home()), "models", "thyroid-unet")
        model_filepath = pick_model(model_path, job.parameters.tile_size,
                                    job.parameters.cytomine_zoom_level)
        device = torch.device(job.parameters.device)
        unet = Unet(job.parameters.init_fmaps, n_classes=1)
        unet.load_state_dict(torch.load(model_filepath, map_location=device))
        unet.to(device)
        unet.eval()

        segmenter = UNetSegmenter(device=job.parameters.device,
                                  unet=unet,
                                  classes=[0, 1],
                                  threshold=job.parameters.threshold)

        working_path = os.path.join(str(Path.home()), "tmp")
        tile_builder = CytomineTileBuilder(working_path)
        builder = SSLWorkflowBuilder()
        builder.set_n_jobs(1)
        builder.set_overlap(job.parameters.tile_overlap)
        builder.set_tile_size(job.parameters.tile_size,
                              job.parameters.tile_size)
        builder.set_tile_builder(tile_builder)
        builder.set_border_tiles(Workflow.BORDER_TILES_EXTEND)
        builder.set_background_class(0)
        builder.set_distance_tolerance(1)
        builder.set_seg_batch_size(job.parameters.batch_size)
        builder.set_segmenter(segmenter)
        workflow = builder.get()

        slide = CytomineSlide(img_instance=ImageInstance().fetch(
            job.parameters.cytomine_id_image),
                              zoom_level=job.parameters.cytomine_zoom_level)
        results = workflow.process(slide)

        print("-------------------------")
        print(len(results))
        print("-------------------------")

        collection = AnnotationCollection()
        for obj in results:
            wkt = shift_poly(obj.polygon,
                             slide,
                             zoom_level=job.parameters.cytomine_zoom_level).wkt
            collection.append(
                Annotation(location=wkt,
                           id_image=job.parameters.cytomine_id_image,
                           id_terms=[154005477],
                           id_project=job.project.id))
        collection.save(n_workers=job.parameters.n_jobs)

        return {}
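# A minimal entry point for the job above (a sketch; `sys` is an assumption, the
# original imports are not shown in this excerpt):
if __name__ == "__main__":
    import sys
    main(sys.argv[1:])

Example #2
# Webcam demo: capture frames with OpenCV, run them through a trained U-Net and
# display the thresholded prediction alongside the input frame.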
def test_video():
    cap = cv2.VideoCapture(2)

    # Model_path and device are assumed to be defined earlier in the original script
    model = Unet(3, 1)
    model.load_state_dict(torch.load(Model_path, map_location=device))

    model.to(device)
    #card_dataset = CardDataset("data/val", transform=x_transforms,target_transform=y_transforms)
    #dataloaders = DataLoader(card_dataset, batch_size=1)
    model.eval()

    with torch.no_grad():
        while True:
            ret, frame = cap.read()
            if not ret:
                print("camera is not ready")
                exit(0)

            frame = frame[0:480, 0:480]
            img = cv2.resize(frame, (512, 512))
            #img = cv2.imread(name,1)

            x = cv2img_process(img)

            x_cuda = x.cuda()

            y = model(x_cuda)
            y_cpu = y.cpu()
            img_y = (torch.squeeze(y_cpu).numpy() * -0.4 * 40 / 255.0 -
                     0.3) / 0.7
            img_y = np.where(img_y < 0.3, 0, img_y)
            img_y = np.where(img_y > 0.3, 1, img_y)

            cv2.imshow("x", img)
            cv2.imshow("predict", img_y)
            #print(img.shape)
            #print(img_y.shape)
            #print("max ",img_y.max())
            #print("min ",img_y.min())
            #print(img_y[250][250])
            cv2.waitKey(1)
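Example #3
# Training-script fragment: optionally wrap the model for multi-GPU use, reload saved
# weights and the best validation loss if present, then train with MSE loss and Adam.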
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
reuse_weights = True
if reuse_weights:
    net.load_state_dict(torch.load('./models/model_{}.pth'.format(name)))
    try:
        best_val_loss = np.load('./models/best_val_loss_{}.npy'.format(name))
    except FileNotFoundError:
        best_val_loss = np.finfo(np.float64).max
    print("Model reloaded. Previous lowest validation loss =",
          str(best_val_loss))
else:
    best_val_loss = np.finfo(np.float64).max

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, weight_decay=1e-4)

best_weights = net.state_dict()
num_epochs = 5
train_loss = np.zeros(num_epochs)
validation_loss = np.zeros(num_epochs)

print('\nStart training')
np.savetxt('epochs_completed.txt', np.zeros(1), fmt='%d')
for epoch in range(num_epochs):  #TODO decide epochs
    print('-----------------Epoch = %d-----------------' % (epoch + 1))
    train_loss[epoch], _ = train(train_loader, net, criterion, optimizer,
                                 device, epoch + 1)
Example #4
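# Test-time inference for the PAVE SSFP data: run a trained two-class U-Net slice by
# slice, assemble per-patient volumes, and write vessel/artery/vein masks as NIfTI files.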
def test():
    model = Unet(5, 2)
    model.load_state_dict(torch.load(args.ckp, map_location='cpu'))
    model.to(device)

    test_root_dir = "test"
    PAVE_dataset = SSFPTestDataset(root=test_root_dir)
    # 828 has to be divisible by batch_size because there are 828 slices per patient
    batch_size = 1
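    # (assumed sanity check, not part of the original) one patient's slices must not be split across batches
    assert 828 % batch_size == 0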

    dataloaders = DataLoader(PAVE_dataset, batch_size=batch_size)
    model.eval()
    #import matplotlib.pyplot as plt
    #plt.ion()

    test_result_dir = "test_result"
    if not os.path.exists(test_result_dir):
        os.makedirs(test_result_dir)

    # one output volume per test patient (10 patients, up to 832 slices each)
    patients = np.zeros((10, 832, 2, 224, 832))

    with torch.no_grad():
        for x, slice_num, patient_num, leg in tqdm(dataloaders):
            x = x.to(device)
            y = model(x)
            output = y.cpu().numpy()

            if leg[0] == 'left':
                patients[patient_num, slice_num + 2, :, :192,
                         80:400] = output[0, :, :, :]
            else:
                patients[patient_num, slice_num + 2, :, :192,
                         480:800] = output[0, :, :, :]

    for patient_num in range(10):
        image_filename = os.path.join("/home/mng/scratch/PAVE_Challenge/test/",
                                      'case{}'.format(patient_num + 1),
                                      'ssfp.nii.gz')

        size_x = nib.load(image_filename).shape[2]

        patient_output_vessels = np.transpose(
            (patients[patient_num, :, 0, :size_x, :] >= 0.5), axes=(2, 0, 1))
        patient_output_arteries = np.transpose(
            (patients[patient_num, :, 1, :size_x, :] >= 0.5), axes=(2, 0, 1))
        patient_output_veins = np.logical_and(
            patient_output_vessels, np.logical_not(patient_output_arteries))

        results_file = os.path.join(
            test_result_dir,
            'case{}_results_vessels.nii.gz'.format(patient_num + 1))
        save_nii(patient_output_vessels.astype(np.uint8), results_file)

        results_file = os.path.join(
            test_result_dir,
            'case{}_results_arteries.nii.gz'.format(patient_num + 1))
        save_nii(patient_output_arteries.astype(np.uint8), results_file)

        results_file = os.path.join(
            test_result_dir,
            'case{}_results_veins.nii.gz'.format(patient_num + 1))
        save_nii(patient_output_veins.astype(np.uint8), results_file)
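# A possible implementation of the save_nii helper used above (an assumption; the
# original helper is not shown in this excerpt):
def save_nii(volume, filename, affine=np.eye(4)):
    """Save a volume as a NIfTI file, using an identity affine unless one is provided."""
    nib.save(nib.Nifti1Image(volume, affine), filename)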
Example #5
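# Training setup for a U-Net (Adam + MSE loss) plus an OpenCV helper that extracts an
# edge mask from a pair of frames.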
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from unet import Unet, Layer
import time
import cv2

MODEL_NAME = f"model-{int(time.time())}"  # gives a dynamic model name, to just help with things getting messy over time.
learning_rate = 0.001
epochs = 4
validation_percentage = 0.1

u_net = Unet()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
u_net.to(device)
optimizer = optim.Adam(u_net.parameters(), lr=learning_rate)
loss_func = nn.MSELoss()


def filter_img(img00, img01):
    kernel = np.ones((4, 4), np.uint8)

    subtract = cv2.subtract((img00 + 15), img01)

    kernel2 = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    img03 = cv2.filter2D(subtract, -1, kernel2)
    img03 = cv2.GaussianBlur(img03, (5, 5), 0)
    img03 = cv2.Canny(img03, 85, 255)
    img03 = cv2.morphologyEx(img03, cv2.MORPH_CLOSE, kernel, iterations=1)
    img03 = cv2.bitwise_not(img03)
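    # (assumed) return the inverted edge mask; the original excerpt ends without an explicit return
    return img03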
Example #6
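# End-to-end training on Cytomine: fetch and filter annotations, download crops, train a
# U-Net with BCE loss, validate each epoch (log-loss, ROC AUC, confusion matrix) and save
# a checkpoint per epoch.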
def main(argv):
    """

    IMAGES VALID:
    * 005-TS_13C08351_2-2014-02-12 12.22.44.ndpi | id : 77150767
    * 024-12C07162_2A-2012-08-14-17.21.05.jp2 | id : 77150761
    * 019-CP_12C04234_2-2012-08-10-12.49.26.jp2 | id : 77150809

    IMAGES TEST:
    * 004-PF_08C11886_1-2012-08-09-19.05.53.jp2 | id : 77150623
    * 011-TS_13C10153_3-2014-02-13 15.22.21.ndpi | id : 77150611
    * 018-PF_07C18435_1-2012-08-17-00.55.09.jp2 | id : 77150755

    """
    with Cytomine.connect_from_cli(argv):
        parser = ArgumentParser()
        parser.add_argument("-b",
                            "--batch_size",
                            dest="batch_size",
                            default=4,
                            type=int)
        parser.add_argument("-j",
                            "--n_jobs",
                            dest="n_jobs",
                            default=1,
                            type=int)
        parser.add_argument("-e",
                            "--epochs",
                            dest="epochs",
                            default=1,
                            type=int)
        parser.add_argument("-d", "--device", dest="device", default="cpu")
        parser.add_argument("-o",
                            "--overlap",
                            dest="overlap",
                            default=0,
                            type=int)
        parser.add_argument("-t",
                            "--tile_size",
                            dest="tile_size",
                            default=256,
                            type=int)
        parser.add_argument("-z",
                            "--zoom_level",
                            dest="zoom_level",
                            default=0,
                            type=int)
        parser.add_argument("--lr", dest="lr", default=0.01, type=float)
        parser.add_argument("--init_fmaps",
                            dest="init_fmaps",
                            default=16,
                            type=int)
        parser.add_argument("--data_path",
                            "--dpath",
                            dest="data_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-w",
                            "--working_path",
                            "--wpath",
                            dest="working_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-s",
                            "--save_path",
                            dest="save_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        args, _ = parser.parse_known_args(argv)

        os.makedirs(args.save_path, exist_ok=True)
        os.makedirs(args.data_path, exist_ok=True)
        os.makedirs(args.working_path, exist_ok=True)

        # fetch annotations (filter val/test sets + other annotations)
        all_annotations = AnnotationCollection(project=77150529,
                                               showWKT=True,
                                               showMeta=True,
                                               showTerm=True).fetch()
        val_ids = {77150767, 77150761, 77150809}
        test_ids = {77150623, 77150611, 77150755}
        val_test_ids = val_ids.union(test_ids)
        train_collection = all_annotations.filter(lambda a: (
            a.user in {55502856} and len(a.term) > 0 and a.term[0] in
            {35777351, 35777321, 35777459} and a.image not in val_test_ids))
        val_rois = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154890363}))
        val_foreground = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154005477}))

        train_wsi_ids = list({an.image
                              for an in all_annotations
                              }.difference(val_test_ids))
        val_wsi_ids = list(val_ids)

        download_path = os.path.join(args.data_path,
                                     "crops-{}".format(args.tile_size))
        images = {
            _id: ImageInstance().fetch(_id)
            for _id in (train_wsi_ids + val_wsi_ids)
        }

        train_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level)
            for annot in train_collection
        ]
        val_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level) for annot in val_rois
        ]

        for crop in train_crops + val_crops:
            crop.download()

        np.random.seed(42)
        dataset = RemoteAnnotationTrainDataset(
            train_crops, seg_trans=segmentation_transform)
        loader = DataLoader(dataset,
                            shuffle=True,
                            batch_size=args.batch_size,
                            num_workers=args.n_jobs,
                            worker_init_fn=worker_init)

        # network
        device = torch.device(args.device)
        unet = Unet(args.init_fmaps, n_classes=1)
        unet.train()
        unet.to(device)

        optimizer = Adam(unet.parameters(), lr=args.lr)
        loss_fn = BCEWithLogitsLoss(reduction="mean")

        results = {
            "train_losses": [],
            "val_losses": [],
            "val_metrics": [],
            "save_path": []
        }

        for e in range(args.epochs):
            print("########################")
            print("        Epoch {}".format(e))
            print("########################")

            epoch_losses = list()
            unet.train()
            for i, (x, y) in enumerate(loader):
                x, y = (t.to(device) for t in [x, y])
                y_pred = unet(x)
                loss = loss_fn(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
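                # keep a short running window of the most recent batch losses for smoother logging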
                epoch_losses = [loss.detach().cpu().item()] + epoch_losses[:5]
                print("{} - {:1.5f}".format(i, np.mean(epoch_losses)))
                results["train_losses"].append(epoch_losses[0])

            unet.eval()
            # validation
            val_losses = np.zeros(len(val_rois), dtype=float)
            val_roc_auc = np.zeros(len(val_rois), dtype=float)
            val_cm = np.zeros([len(val_rois), 2, 2], dtype=int)

            for i, roi in enumerate(val_crops):
                foregrounds = find_intersecting_annotations(
                    roi.annotation, val_foreground)
                with torch.no_grad():
                    y_pred, y_true = predict_roi(
                        roi,
                        foregrounds,
                        unet,
                        device,
                        in_trans=transforms.ToTensor(),
                        batch_size=args.batch_size,
                        tile_size=args.tile_size,
                        overlap=args.overlap,
                        n_jobs=args.n_jobs,
                        zoom_level=args.zoom_level)

                val_losses[i] = metrics.log_loss(y_true.flatten(),
                                                 y_pred.flatten())
                val_roc_auc[i] = metrics.roc_auc_score(y_true.flatten(),
                                                       y_pred.flatten())
                val_cm[i] = metrics.confusion_matrix(
                    y_true.flatten().astype(np.uint8),
                    (y_pred.flatten() > 0.5).astype(np.uint8))

            print("------------------------------")
            print("Epoch {}:".format(e))
            val_loss = np.mean(val_losses)
            roc_auc = np.mean(val_roc_auc)
            print("> val_loss: {:1.5f}".format(val_loss))
            print("> roc_auc : {:1.5f}".format(roc_auc))
            cm = np.sum(val_cm, axis=0)
            cnt = np.sum(val_cm)
            print("CM at 0.5 threshold")
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[0, 0] / cnt,
                                                100 * cm[0, 1] / cnt))
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[1, 0] / cnt,
                                                100 * cm[1, 1] / cnt))
            print("------------------------------")

            filename = "{}_e_{}_val_{:0.4f}_roc_{:0.4f}_z{}_s{}.pth".format(
                datetime.now().timestamp(), e, val_loss, roc_auc,
                args.zoom_level, args.tile_size)
            torch.save(unet.state_dict(), os.path.join(args.save_path,
                                                       filename))

            results["val_losses"].append(val_loss)
            results["val_metrics"].append(roc_auc)
            results["save_path"].append(filename)

        return results
Example #7
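# Argument parsing plus a train/eval driver. The excerpt starts mid-way through the
# parser definition; the opening below is an assumed reconstruction, and the earlier
# add_argument calls (e.g. --organ, --epochs, --lr, --batchsize) are not shown.
def get_args():
    parser = argparse.ArgumentParser()  # assumes `import argparse`
    # ... earlier arguments omitted in this excerpt ...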
    parser.add_argument('--load', type=str, default='CP_epoch20.pth',
                        help='Load model')
    parser.add_argument('--test', type=str, default=None,
                        help='Load test dataset')
    parser.add_argument('--no-train', dest='train', action='store_false',
                        help='Skip the training phase')
    parser.set_defaults(train=True)
    return parser.parse_args()



if __name__ == '__main__':
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   
    model = Unet(n_channels=3, n_classes=1)
    model.to(device=device)
    

    train_data, test_data, maps = load_data(organ=args.organ)        

    if args.train:
        train_net(net=model, body_array_new=train_data, maps_list=maps, epochs=args.epochs, lr=args.lr, device=device, batch_size=args.batchsize)
    
        
    model_eval = Unet(n_channels=3, n_classes=1)
    model_eval.load_state_dict(torch.load(dir_checkpoint + args.load, map_location=device))
    model_eval.to(device=device)
    eval_net(model=model_eval, body_array_new=test_data, maps_list=maps, device=device)
    
   
Example #8
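# Dataset and dataloader setup from a training script: the excerpt starts mid-way through
# building the LSUN 'bedroom' train/val datasets, then creates train, validation and mask
# loaders and calls train().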
                                            classes=['bedroom_train'],
                                            transform=transform)
        val_dataset = torchvision.datasets.LSUN(args.val_data,
                                                classes=['bedroom_val'],
                                                transform=transform)

    if args.mini_eval:
        if args.dataset == 'folder':
            val_dataset, _ = torch.utils.data.random_split(
                val_dataset,
                [args.mini_eval,
                 len(val_dataset) - args.mini_eval])
        else:
            dataset, val_dataset = torch.utils.data.random_split(
                dataset, [len(dataset) - args.mini_eval, args.mini_eval])

    dataloader = torch.utils.data.DataLoader(dataset,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             drop_last=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 shuffle=True,
                                                 batch_size=args.batch_size,
                                                 drop_last=True)
    mask_generator = MaskGenerator(args.input_size, file_path=args.mask)
    maskloader = torch.utils.data.DataLoader(mask_generator,
                                             shuffle=False,
                                             batch_size=args.batch_size)
    maskiter = iter(maskloader)
    train(model.to(device), criterion, dataloader, maskloader, val_dataloader)