Example #1
def main():
    args = build_parser().parse_args()
    image_size = [args.img_height, args.img_width]
    # config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 1.0
    # sess = tf.Session(config=config)
    sess = tf.Session()
    unet = Unet(input_shape=image_size,
                sess=sess,
                filter_num=args.filter_num,
                batch_norm=args.batch_norm)
    unet.build_net()
    if args.checkpoint_path:
        unet.load_weights(args.checkpoint_path)

    images, masks = read_data(args.train_dir,
                              args.train_mask_dir,
                              n_images=args.n_images,
                              image_size=image_size)
    val_images, val_masks = read_data(args.val_dir,
                                      args.val_mask_dir,
                                      n_images=args.n_images // 4,
                                      image_size=image_size)
    unet.train(images=images,
               masks=masks,
               val_images=val_images,
               val_masks=val_masks,
               epochs=args.epochs,
               batch_size=args.batch_size,
               learning_rate=args.learning_rate,
               dice_loss=args.dice_loss,
               always_save=args.always_save)
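
Example #1 relies on a build_parser() helper that is not shown in the snippet. Below is a minimal sketch of what it could look like, reconstructed only from the attributes accessed on args in main(); the defaults and the required/optional choices are assumptions, not the original parser.

import argparse

def build_parser():
    # Hypothetical parser inferred from the arguments used in main() above.
    parser = argparse.ArgumentParser(description="Train a U-Net for segmentation")
    parser.add_argument("--img_height", type=int, default=256)
    parser.add_argument("--img_width", type=int, default=256)
    parser.add_argument("--filter_num", type=int, default=64)
    parser.add_argument("--batch_norm", action="store_true")
    parser.add_argument("--checkpoint_path", default=None)
    parser.add_argument("--train_dir", required=True)
    parser.add_argument("--train_mask_dir", required=True)
    parser.add_argument("--val_dir", required=True)
    parser.add_argument("--val_mask_dir", required=True)
    parser.add_argument("--n_images", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--dice_loss", action="store_true")
    parser.add_argument("--always_save", action="store_true")
    return parser
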
Example #2
def train():
    model = Unet(5, 2).to(device)
    model.train()
    batch_size = args.batch_size
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    PAVE_dataset = SSFPDataset("train", transform=1, target_transform=1)
    dataloaders = DataLoader(PAVE_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=0)
    train_model(model, criterion, optimizer, dataloaders)
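
Example #2 delegates the optimization loop to a train_model() helper that is not shown, and it assumes a global device. The sketch below only illustrates a loop compatible with that call, reusing the torch import and global device from the snippet; the epoch count, the sigmoid applied to the model output, and the logging are assumptions rather than the original code.

def train_model(model, criterion, optimizer, dataloaders, epochs=10):
    # Hypothetical loop matching train_model(model, criterion, optimizer, dataloaders).
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, targets in dataloaders:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = torch.sigmoid(model(inputs))  # BCELoss expects probabilities in [0, 1]
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        print(f"epoch {epoch}: mean loss {running_loss / len(dataloaders.dataset):.4f}")
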
Example #3
        itt = itt + 1
        scheduler_unet_qpi2dapi = optim.lr_scheduler.StepLR(
            optimizer_unet_qpi2dapi, 2000, gamma=0.1, last_epoch=-1)
        scheduler_unet_dapi2qpi = optim.lr_scheduler.StepLR(
            optimizer_unet_dapi2qpi, 2000, gamma=0.1, last_epoch=-1)
        scheduler_D_dapi = optim.lr_scheduler.StepLR(optimizer_D_dapi,
                                                     2000,
                                                     gamma=0.1,
                                                     last_epoch=-1)
        scheduler_D_qpi = optim.lr_scheduler.StepLR(optimizer_D_qpi,
                                                    2000,
                                                    gamma=0.1,
                                                    last_epoch=-1)

        unet_qpi2dapi.train()
        D_dapi.train()
        unet_dapi2qpi.train()
        D_qpi.train()
        #
        #        for p in unet_qpi2dapi.parameters():
        #            p.requires_grad = False
        #        for p in unet_dapi2qpi.parameters():
        #            p.requires_grad = False
        #        for p in D_dapi.parameters():
        #            p.requires_grad = True
        #        for p in D_qpi.parameters():
        #            p.requires_grad = True
        #
        #        for t in range(n_critic):
        #
Example #4
        train_iters = []
        test_iters = [0]
        stop = 0
        itt = 0
        start_sl = 0

        while itt < iterace and stop == 0:

            itt = itt + 1

            scheduler1.step()
            scheduler2.step()

            unet.train()
            D.train()

            for p in unet.parameters():
                p.requires_grad = False

            for p in D.parameters():
                p.requires_grad = True

            if alpha > 0:
                for t in range(n_critic):

                    (in_images, out_images, pat) = next(gen)

                    in_images = in_images.cuda(0)
                    out_images = out_images.cuda(0)
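
Example #4 is cut off inside the critic loop. In this kind of adversarial setup the inner for t in range(n_critic) loop would typically perform one discriminator update per iteration; the continuation below is a hypothetical sketch, and the optimizer name optimizer_D as well as the WGAN-style loss form are assumptions, not the original code.

                    # Hypothetical critic update (not part of the original snippet).
                    optimizer_D.zero_grad()
                    fake_images = unet(in_images).detach()  # no gradient flows back into the generator
                    d_loss = D(fake_images).mean() - D(out_images).mean()  # WGAN-style critic objective
                    d_loss.backward()
                    optimizer_D.step()
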
Example #5
        model.eval()
        fixed_pred, _ = model(fixed_masked, fixed_mask)
        torchvision.utils.save_image(fixed_pred, os.path.join(results, f'pred_{i}.jpg'), normalize=True)
        print(fixed_pred.min(), fixed_pred.max(), fixed_pred.mean())
        print(loss.item())

        if not training:
            print(fixed_orig.min(), fixed_orig.max(), fixed_orig.mean())
            print(fixed_masked.min(), fixed_masked.max(), fixed_masked.mean())
            print(fixed_mask.min(), fixed_mask.max(), fixed_mask.mean())
            torchvision.utils.save_image(fixed_masked, os.path.join(results, f'masked_{i}.jpg'), normalize=True)
            torchvision.utils.save_image(fixed_orig, os.path.join(results, f'orig_{i}.jpg'), normalize=True)
save_image(0, fixed_masked, fixed_mask, False, torch.tensor(0))

for i in range(1, 201):
    print(f'{i}..')
    for [masked, mask], orig in dl:
        model.train()
        masked = masked.to(device)
        mask = mask.to(device)
        orig = orig.to(device)
        optimizer.zero_grad()
        pred, _ = model(masked, mask)
        loss = criterion(orig, mask, pred)
        loss.backward()
        optimizer.step()

    save_image(i, fixed_masked, fixed_mask, True, loss)

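The loss in Example #5 is a custom criterion(orig, mask, pred) whose implementation is not shown. A plausible masked-reconstruction loss for this inpainting setup is sketched below; the L1 distance, the mask convention (1 = visible pixel), and the 6:1 hole weighting are all assumptions.

import torch.nn.functional as F

def criterion(orig, mask, pred, hole_weight=6.0):
    # Hypothetical inpainting loss: weight the error inside the hole more heavily.
    hole = F.l1_loss(pred * (1 - mask), orig * (1 - mask))   # masked-out region
    valid = F.l1_loss(pred * mask, orig * mask)              # visible region
    return hole_weight * hole + valid
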
Example #6
def main(argv):
    """

    IMAGES VALID:
    * 005-TS_13C08351_2-2014-02-12 12.22.44.ndpi | id : 77150767
    * 024-12C07162_2A-2012-08-14-17.21.05.jp2 | id : 77150761
    * 019-CP_12C04234_2-2012-08-10-12.49.26.jp2 | id : 77150809

    IMAGES TEST:
    * 004-PF_08C11886_1-2012-08-09-19.05.53.jp2 | id : 77150623
    * 011-TS_13C10153_3-2014-02-13 15.22.21.ndpi | id : 77150611
    * 018-PF_07C18435_1-2012-08-17-00.55.09.jp2 | id : 77150755

    """
    with Cytomine.connect_from_cli(argv):
        parser = ArgumentParser()
        parser.add_argument("-b",
                            "--batch_size",
                            dest="batch_size",
                            default=4,
                            type=int)
        parser.add_argument("-j",
                            "--n_jobs",
                            dest="n_jobs",
                            default=1,
                            type=int)
        parser.add_argument("-e",
                            "--epochs",
                            dest="epochs",
                            default=1,
                            type=int)
        parser.add_argument("-d", "--device", dest="device", default="cpu")
        parser.add_argument("-o",
                            "--overlap",
                            dest="overlap",
                            default=0,
                            type=int)
        parser.add_argument("-t",
                            "--tile_size",
                            dest="tile_size",
                            default=256,
                            type=int)
        parser.add_argument("-z",
                            "--zoom_level",
                            dest="zoom_level",
                            default=0,
                            type=int)
        parser.add_argument("--lr", dest="lr", default=0.01, type=float)
        parser.add_argument("--init_fmaps",
                            dest="init_fmaps",
                            default=16,
                            type=int)
        parser.add_argument("--data_path",
                            "--dpath",
                            dest="data_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-w",
                            "--working_path",
                            "--wpath",
                            dest="working_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-s",
                            "--save_path",
                            dest="save_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        args, _ = parser.parse_known_args(argv)

        os.makedirs(args.save_path, exist_ok=True)
        os.makedirs(args.data_path, exist_ok=True)
        os.makedirs(args.working_path, exist_ok=True)

        # fetch annotations (filter val/test sets + other annotations)
        all_annotations = AnnotationCollection(project=77150529,
                                               showWKT=True,
                                               showMeta=True,
                                               showTerm=True).fetch()
        val_ids = {77150767, 77150761, 77150809}
        test_ids = {77150623, 77150611, 77150755}
        val_test_ids = val_ids.union(test_ids)
        train_collection = all_annotations.filter(lambda a: (
            a.user in {55502856} and len(a.term) > 0 and a.term[0] in
            {35777351, 35777321, 35777459} and a.image not in val_test_ids))
        val_rois = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154890363}))
        val_foreground = all_annotations.filter(
            lambda a: (a.user in {142954314} and a.image in val_ids and len(
                a.term) > 0 and a.term[0] in {154005477}))

        train_wsi_ids = list({an.image
                              for an in all_annotations
                              }.difference(val_test_ids))
        val_wsi_ids = list(val_ids)

        download_path = os.path.join(args.data_path,
                                     "crops-{}".format(args.tile_size))
        images = {
            _id: ImageInstance().fetch(_id)
            for _id in (train_wsi_ids + val_wsi_ids)
        }

        train_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level)
            for annot in train_collection
        ]
        val_crops = [
            AnnotationCrop(images[annot.image],
                           annot,
                           download_path,
                           args.tile_size,
                           zoom_level=args.zoom_level) for annot in val_rois
        ]

        for crop in train_crops + val_crops:
            crop.download()

        np.random.seed(42)
        dataset = RemoteAnnotationTrainDataset(
            train_crops, seg_trans=segmentation_transform)
        loader = DataLoader(dataset,
                            shuffle=True,
                            batch_size=args.batch_size,
                            num_workers=args.n_jobs,
                            worker_init_fn=worker_init)

        # network
        device = torch.device(args.device)
        unet = Unet(args.init_fmaps, n_classes=1)
        unet.train()
        unet.to(device)

        optimizer = Adam(unet.parameters(), lr=args.lr)
        loss_fn = BCEWithLogitsLoss(reduction="mean")

        results = {
            "train_losses": [],
            "val_losses": [],
            "val_metrics": [],
            "save_path": []
        }

        for e in range(args.epochs):
            print("########################")
            print("        Epoch {}".format(e))
            print("########################")

            epoch_losses = list()
            unet.train()
            for i, (x, y) in enumerate(loader):
                x, y = (t.to(device) for t in [x, y])
                y_pred = unet(x)
                loss = loss_fn(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # keep only the most recent losses for a running-mean display
                epoch_losses = [loss.detach().cpu().item()] + epoch_losses[:5]
                print("{} - {:1.5f}".format(i, np.mean(epoch_losses)))
                results["train_losses"].append(epoch_losses[0])

            unet.eval()
            # validation
            val_losses = np.zeros(len(val_rois), dtype=float)
            val_roc_auc = np.zeros(len(val_rois), dtype=float)
            val_cm = np.zeros([len(val_rois), 2, 2], dtype=int)

            for i, roi in enumerate(val_crops):
                foregrounds = find_intersecting_annotations(
                    roi.annotation, val_foreground)
                with torch.no_grad():
                    y_pred, y_true = predict_roi(
                        roi,
                        foregrounds,
                        unet,
                        device,
                        in_trans=transforms.ToTensor(),
                        batch_size=args.batch_size,
                        tile_size=args.tile_size,
                        overlap=args.overlap,
                        n_jobs=args.n_jobs,
                        zoom_level=args.zoom_level)

                val_losses[i] = metrics.log_loss(y_true.flatten(),
                                                 y_pred.flatten())
                val_roc_auc[i] = metrics.roc_auc_score(y_true.flatten(),
                                                       y_pred.flatten())
                val_cm[i] = metrics.confusion_matrix(
                    y_true.flatten().astype(np.uint8),
                    (y_pred.flatten() > 0.5).astype(np.uint8))

            print("------------------------------")
            print("Epoch {}:".format(e))
            val_loss = np.mean(val_losses)
            roc_auc = np.mean(val_roc_auc)
            print("> val_loss: {:1.5f}".format(val_loss))
            print("> roc_auc : {:1.5f}".format(roc_auc))
            cm = np.sum(val_cm, axis=0)
            cnt = np.sum(val_cm)
            print("CM at 0.5 threshold")
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[0, 0] / cnt,
                                                100 * cm[0, 1] / cnt))
            print("> {:3.2f}%  {:3.2f}%".format(100 * cm[1, 0] / cnt,
                                                100 * cm[1, 1] / cnt))
            print("------------------------------")

            filename = "{}_e_{}_val_{:0.4f}_roc_{:0.4f}_z{}_s{}.pth".format(
                datetime.now().timestamp(), e, val_loss, roc_auc,
                args.zoom_level, args.tile_size)
            torch.save(unet.state_dict(), os.path.join(args.save_path,
                                                       filename))

            results["val_losses"].append(val_loss)
            results["val_metrics"].append(roc_auc)
            results["save_path"].append(filename)

        return results
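
Example #6 passes a worker_init callable to the DataLoader without showing it. Given the np.random.seed(42) call just before the dataset is built, a common pattern is to give each worker process its own NumPy seed; the sketch below is such an assumption, not the original helper.

import numpy as np
import torch

def worker_init(worker_id):
    # Hypothetical worker_init_fn: derive a distinct NumPy seed per DataLoader worker
    # so the workers do not apply identical random augmentations.
    seed = (torch.initial_seed() + worker_id) % (2 ** 32)
    np.random.seed(seed)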
Example #7
it = -1
for epoch in range(1000):
    for k, (data, lbl) in enumerate(trainloader):
        it += 1
        print(it)

        data = data.cuda()
        lbl = lbl.cuda()

        # requires_grad on the inputs/labels is not needed to train the network itself
        data.requires_grad = True
        lbl.requires_grad = True

        optimizer.zero_grad()  # zero the gradient buffers

        net.train()

        output = net(data)
        output = torch.sigmoid(output)  # torch.sigmoid; F.sigmoid is deprecated

        loss = dice_loss(output, lbl)  ### here the MSE for denoising is computed....

        loss.backward()  ## calculate gradients
        optimizer.step()  ## update parameters

        clas = (output > 0.5).float()

        acc = torch.mean((clas == lbl).float())

        train_acc_tmp.append(acc.detach().cpu().numpy())
        train_loss_tmp.append(loss.detach().cpu().numpy())
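
The dice_loss called in Example #7 is not shown either. A standard soft Dice loss for binary segmentation that matches the call dice_loss(output, lbl) on sigmoid probabilities is sketched below; the smoothing constant is a typical choice, not necessarily the original one.

def dice_loss(pred, target, smooth=1.0):
    # Soft Dice loss: 1 - Dice coefficient, computed per sample on flattened probabilities.
    pred = pred.contiguous().view(pred.size(0), -1)
    target = target.contiguous().view(target.size(0), -1)
    intersection = (pred * target).sum(dim=1)
    dice = (2.0 * intersection + smooth) / (pred.sum(dim=1) + target.sum(dim=1) + smooth)
    return 1.0 - dice.mean()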