Ejemplo n.º 1
0
def main(args: Namespace):
    """Reconstruct one CT volume with a trained VQVAE and write it as NRRD.

    Takes the first sample of the training dataloader, runs it through the
    VQVAE checkpoint on the GPU, maps the output back to integer intensities
    and writes the result to ``args.out_path``.

    Args:
        args: parsed CLI arguments; reads ``dataset_path``, ``rescale_input``,
            ``ckpt_path`` and ``out_path``.
    """
    pl.seed_everything(42)

    # Only the scale is needed to undo the output normalization.
    # NOTE(review): earlier revisions also defined min_val=-1500 and
    # max_val=3000 here but never used them.
    scale_val = 1000

    print("- Loading dataloader")
    datamodule = CTDataModule(path=args.dataset_path,
                              train_frac=1,
                              batch_size=1,
                              num_workers=0,
                              rescale_input=args.rescale_input)
    datamodule.setup()
    train_loader = datamodule.train_dataloader()

    print("- Loading single CT sample")
    single_sample, _ = next(iter(train_loader))
    single_sample = single_sample.cuda()

    print("- Loading model weights")
    model = VQVAE.load_from_checkpoint(str(args.ckpt_path)).cuda()

    print("- Performing forward pass")
    with torch.no_grad(), torch.cuda.amp.autocast():
        res, *_ = model(single_sample)
        res = torch.nn.functional.elu(res)

    res = res.squeeze().detach().cpu().numpy()
    # NOTE(review): presumably inverts the dataset normalization
    # x_norm = (x + scale_val) / scale_val — confirm against CTDataModule.
    res = res * scale_val - scale_val
    # `np.int` was removed in NumPy 1.24; it was only an alias of builtin int.
    res = np.rint(res).astype(int)

    print("- Writing to nrrd")
    nrrd.write(str(args.out_path), res, header={'spacings': (0.976, 0.976, 3)})

    print("- Done")
Ejemplo n.º 2
0
def main(args):
    """Encode every training sample with a VQVAE and store the codes in LMDB.

    One named sub-database is opened per bottleneck block. In sub-database
    ``j``, key ``str(i)`` holds the pickled encoding of sample ``i`` for block
    ``j``. The root database stores ``num_dbs``, ``length`` (number of samples)
    and ``num_embeddings`` (pickled numpy array).

    Args:
        args: parsed CLI arguments; reads ``checkpoint_path`` (None means an
            untrained ``VQVAE()`` is used), ``dataset_path``, ``output_path``
            and ``output_name``.
    """
    if args.checkpoint_path is not None:
        model = VQVAE.load_from_checkpoint(str(args.checkpoint_path))
    else:
        model = VQVAE()

    datamodule = CTDataModule(
        path=args.dataset_path,
        batch_size=1,
        train_frac=1,
        num_workers=5,
        rescale_input=(256, 256, 128),
    )
    datamodule.setup()
    dataloader = datamodule.train_dataloader()

    db = lmdb.open(
        get_output_abspath(args.checkpoint_path, args.output_path, args.output_name),
        map_size=int(1e12),  # maximum size the LMDB map may grow to (bytes)
        max_dbs=model.n_bottleneck_blocks,  # one named sub-db per block
    )
    try:
        sub_dbs = [db.open_db(str(i).encode()) for i in range(model.n_bottleneck_blocks)]
        with db.begin(write=True) as txn:
            # Write root db metadata
            txn.put(b"num_dbs", str(model.n_bottleneck_blocks).encode())
            txn.put(b"length", str(len(dataloader)).encode())
            txn.put(b"num_embeddings", pickle.dumps(np.asarray(model.num_embeddings)))

            encodings_iter = enumerate(extract_samples(model, dataloader))
            for i, sample_encodings in tqdm(encodings_iter, total=len(dataloader)):
                # NOTE(review): zip() silently truncates if the number of
                # encodings differs from n_bottleneck_blocks — confirm they match.
                for sub_db, encoding in zip(sub_dbs, sample_encodings):
                    txn.put(str(i).encode(), pickle.dumps(encoding.cpu().numpy()), db=sub_db)
    finally:
        # Release the environment even if encoding or writing fails part-way.
        db.close()
Ejemplo n.º 3
0
def main(args: Namespace):
    """Decode stored latent code pairs with a VQVAE and write each as NRRD.

    ``db[0]`` maps keys to dicts with ``'data'`` (level-0 codes) and
    ``'condition'`` (a key into ``db[1]``, whose entries also carry ``'data'``).
    Each pair is embedded with the matching quantizer, decoded, mapped back to
    integer intensities and written to
    ``{out_path}_{success|failure}_{key1}_{key0}.nrrd``.

    Args:
        args: parsed CLI arguments; reads ``ckpt_path``, ``db_path`` and
            ``out_path``.
    """
    # Only the scale is needed to undo the output normalization.
    # NOTE(review): earlier revisions also defined min_val=-1500 and
    # max_val=3000 here but never used them.
    scale_val = 1000

    print("- Loading model weights")
    model = VQVAE.load_from_checkpoint(str(args.ckpt_path)).cuda()

    # NOTE(review): torch.load unpickles arbitrary objects — only use trusted files.
    db = torch.load(args.db_path)

    # Pure inference: no autograd graph is needed for embedding or decoding.
    with torch.no_grad():
        for embedding_0_key, embedding_0 in db[0].items():
            embedding_1_key = embedding_0['condition']
            embedding_1 = db[1][embedding_1_key]

            # issue where the pixelcnn samples 0's
            success = 'failure' if torch.all(
                embedding_0['data'][-1] == 0) else 'success'

            # Look up code vectors and move the channel axis to dim 1
            # (permute(0, 4, 1, 2, 3) on the 5-D embedding output).
            embeddings = [
                quantizer.embed_code(embedding['data'].cuda().unsqueeze(
                    dim=0)).permute(0, 4, 1, 2, 3)
                for embedding, quantizer in zip((
                    embedding_0, embedding_1), model.encoder.quantize)
            ]

            print("- Performing forward pass")
            with torch.cuda.amp.autocast():
                res = model.decode(embeddings)
                res = torch.nn.functional.elu(res)

            res = res.squeeze().detach().cpu().numpy()
            res = res * scale_val - scale_val
            # `np.int` was removed in NumPy 1.24; it was only an alias of builtin int.
            res = np.rint(res).astype(int)

            print("- Writing to nrrd")
            nrrd.write(
                str(args.out_path) +
                f'_{success}_{str(embedding_1_key)}_{str(embedding_0_key)}.nrrd',
                res,
                header={'spacings': (0.976, 0.976, 3)})

            print("- Done")
def main(args: Namespace):
    """Compute per-batch SSIM of VQVAE reconstructions on val and train splits.

    The resulting ``val_ssims`` / ``train_ssims`` tensors are left in scope and
    execution stops at a final ``breakpoint()`` for manual inspection.

    Args:
        args: parsed CLI arguments; reads ``dataset_path`` and ``ckpt_path``.
    """
    # Same seed as used in train.py, so that train/val splits are also the same
    pl.seed_everything(seed=42)

    print("- Loading datamodule")
    datamodule = CTDataModule(path=args.dataset_path,
                              batch_size=5,
                              num_workers=5)  # mypy: ignore
    datamodule.setup()

    train_dl = datamodule.train_dataloader()
    val_dl = datamodule.val_dataloader()

    print("- Loading model weights")
    model = VQVAE.load_from_checkpoint(str(args.ckpt_path)).cuda()

    # NOTE(review): presumably the value range of the normalized inputs — confirm.
    data_min, data_max = -0.24, 4
    data_range = data_max - data_min

    # Separate metric objects per split so any internal state is not shared.
    train_ssim = SSIM3DSlices(data_range=data_range)
    val_ssim = SSIM3DSlices(data_range=data_range)

    def batch_ssim(batch, ssim_f):
        """Reconstruct ``batch`` with the model and score it with ``ssim_f``."""
        batch = batch.cuda()
        out, *_ = model(batch)
        out = F.elu(out)
        # Use the metric passed in (previously val_ssim was always used).
        return ssim_f(out.float(), batch)

    with torch.no_grad(), torch.cuda.amp.autocast():
        val_ssims = torch.Tensor(
            [batch_ssim(batch, ssim_f=val_ssim) for batch, _ in tqdm(val_dl)])
        train_ssims = torch.Tensor([
            batch_ssim(batch, ssim_f=train_ssim) for batch, _ in tqdm(train_dl)
        ])

    # breakpoint for manual decision what to do with train_ssims/val_ssims
    # TODO: find some better solution to described above
    breakpoint()