Example #1
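# Shared imports assumed by the snippets below (they are not part of the
# original examples); create_train_idxs, Dataset3dH5, Dataset3dNonlinear,
# Dataset3dNonlinearAnonymized, VAE, loss_function, mseloss, mahalanobis2,
# calc_icov and ztransform are project-local helpers.
import argparse
import os
from copy import deepcopy

import h5py
import hdf5plugin
import nibabel as nb
import numpy as np
import pandas as pd
import torch
from torch import optim
from torch.utils import data
from tqdm import tqdm
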
def direct_project(y, normalizationmode, nostd=False, anon=False):
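    """Project y onto the z-scored training volumes one at a time.

    Sequentially subtracts from y its (head-mask restricted) projection onto
    each training volume; since the training volumes are not orthogonalized
    first, this is a stack of one-dimensional projections rather than an exact
    least-squares projection onto their span. Returns (projection, residual).
    """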
    train_ids, _ = create_train_idxs(395)
    if anon:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii").get_fdata()
    else:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii").get_fdata()
    mask = mask[:192,2:226,:192][None,None,:,:,:]
    mask = torch.Tensor(mask)
    
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:],3,0)[None,:,:,:,:].astype(np.float32)
        std = np.sqrt(np.moveaxis(hfout["var"][:],3,0))[None,:,:,:,:].astype(np.float32)
    y_z = (y-mean)/(1.0 if nostd else std)
    residual = deepcopy(y_z)
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        for ii,i in enumerate(train_ids):
            print(f"{ii},{i}")
            xi_z = torch.Tensor((x[i].astype(np.float32)-mean)/(1.0 if nostd else std))
            xi_norm = torch.norm(xi_z*mask)
            coef = (residual*xi_z*mask).sum()/xi_norm
            print(coef)
            residual = residual - coef * xi_z/xi_norm
            #print(np.linalg.norm(residual))
    return y_z - residual, residual
Example #2
def compute_kernelrow(y, normalizationmode, nostd=False):
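    """Compute one row of the training kernel matrix for volume y.

    Each entry is the mask-averaged inner product between the z-scored y and
    one z-scored training volume.
    """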
    mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii").get_fdata()
    mask = mask[:192,2:226,:192][None,None,:,:,:]
    mask = torch.Tensor(mask)
    Nmask = mask.sum()
    train_ids, _ = create_train_idxs(395)
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:],3,0)[None,:,:,:,:].astype(np.float32)
        std = np.sqrt(np.moveaxis(hfout["var"][:],3,0))[None,:,:,:,:].astype(np.float32)
    yz = (y-mean)
    if not nostd:
        yz /= std
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        kmeanrow = torch.zeros(1,1,1,1,len(train_ids))
        for ii, i in enumerate(train_ids):
            print(i)
            xi = torch.Tensor((x[i].astype(np.float32)-mean)/(1.0 if nostd else std))
            # index by position ii, not dataset id i: kmeanrow only has
            # len(train_ids) entries, while i ranges over the full dataset
            kmeanrow[0,0,0,0,ii] = torch.sum((yz*xi)*mask/Nmask)
            print(kmeanrow[0,0,0,0,ii])
    return kmeanrow
Example #3
def pca_project(y, normalizationmode, nostd=False):
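    """Kernel (dual) formulation of the projection onto the training volumes.

    Solves (K + eps*I) c = k(y) for the combination coefficients, where K is
    the precomputed training kernel and k(y) is the row computed by
    compute_kernelrow_local; eps is a small ridge term that keeps the inverse
    stable. Returns (projection, residual).
    """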
    train_ids, _ = create_train_idxs(395)
    #train_ids = train_ids[:10]
    #kernelrow = compute_kernelrow(y, normalizationmode, nostd)
    kernelrow = compute_kernelrow_local(y, normalizationmode, nostd)
    print(kernelrow)
    with h5py.File(f"localkernel_{normalizationmode}{'_nostd' if nostd else ''}.h5", "r") as hfout:
        kernel = torch.Tensor(hfout['kernel'][0,0,0])
    kernel = kernel[train_ids,:][:,train_ids]
    eps = kernel.diagonal().mean()/10
    coefrow = torch.matmul(kernelrow[0,0,0], torch.inverse(kernel + eps*torch.eye(kernel.shape[-1])))
    projection = 0*y
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:],3,0)[None,:,:,:,:].astype(np.float32)
        std = np.sqrt(np.moveaxis(hfout["var"][:],3,0))[None,:,:,:,:].astype(np.float32)
    y_z = (y-mean)/(1.0 if nostd else std)
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        for ii, i in enumerate(train_ids):
            print(f"{ii}: coef {coefrow[0,ii]}")
            # .item() keeps the accumulation in numpy instead of mixing
            # torch scalars into the numpy array
            projection += coefrow[0,ii].item() * (x[i].astype(np.float32)-mean)/(1.0 if nostd else std)
    return projection, y_z - projection
Example #4
def pca_project_test(test_ids, cuda=False):
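    """Compute projection residuals for a batch of held-out volumes.

    Solves for the kernel combination coefficients of all test volumes at once
    using the precomputed global kernel (localkernel_z.h5), then subtracts the
    coefficient-weighted training volumes from the z-scored test volumes in a
    single pass over the h5 file.
    """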
    train_ids, _ = create_train_idxs(395)
    with h5py.File(f"localkernel_z.h5", "r") as hfout:
        kernel = torch.Tensor(hfout['kernel'][0,0,0])
        if cuda: kernel.cuda()
    kernelrows = kernel[test_ids,:][:,train_ids]
    kernel = kernel[train_ids,:][:,train_ids]
    eps = kernel.diagonal().mean()/10
    epseye = torch.eye(kernel.shape[-1])
    if cuda: epseye = epseye.cuda()
    coefrows = torch.matmul(kernelrows, torch.inverse(kernel + eps*epseye))
    print(coefrows)
    residuals = torch.zeros(len(test_ids),9,192,224,192)
    if cuda: residuals = residuals.cuda()
    with h5py.File("icov_x_z.h5", "r") as hfout:
        mean = torch.Tensor(np.moveaxis(hfout["mean"][:],3,0)[None,:,:,:,:])
        std = torch.Tensor(np.sqrt(np.moveaxis(hfout["var"][:],3,0))[None,:,:,:,:])
        if cuda:
            mean = mean.cuda()
            std = std.cuda()
    print("loading h5 file")
    with h5py.File(f"x_z.h5", "r") as hf:
        x = hf['x']
        print("loading test data")
        for ii,i in enumerate(test_ids):
            print(ii)
            xi = torch.Tensor(x[i].astype(np.float32))
            if cuda: xi = xi.cuda()
            residuals[ii] = (xi-mean)/std
        print("updating residuals")
        for ii,i in enumerate(train_ids):
            xi = torch.Tensor(x[i].astype(np.float32))
            if cuda: xi = xi.cuda()
            print(ii)
            residuals -= coefrows[:,ii].view(-1,1,1,1,1) * (xi-mean)/std
            #print((residuals**2).sum())
    return residuals
Example #5
def compute_kernelrow_local(y, normalizationmode, nostd=False, anon=False):
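    """Chunk-streaming variant of compute_kernelrow.

    Iterates over the h5 dataset chunk by chunk so that only one spatial chunk
    of the training volumes is in memory at a time, accumulating the
    mask-weighted inner products into the kernel row. Chunks containing no
    head-mask voxels are skipped.
    """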
    if anon:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii").get_fdata()
    else:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii").get_fdata()
    mask = mask[:192,2:226,:192][None,None,:,:,:]
    mask = torch.Tensor(mask)
    train_ids, _ = create_train_idxs(395)
    #train_ids = train_ids[:10]
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:],3,0)[None,:,:,:,:].astype(np.float32)
        std = np.sqrt(np.moveaxis(hfout["var"][:],3,0))[None,:,:,:,:].astype(np.float32)
    y_z = (y-mean)/(1.0 if nostd else std)
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        n = len(train_ids)
        chunks = x.chunks
        kmeanrow = torch.zeros(1,1,1,1,n)
        Nmask = mask.sum()
        for i in range(x.shape[2]//chunks[2]):
            for j in range(x.shape[3]//chunks[3]):
                for k in range(x.shape[4]//chunks[4]):
                    print([i,j,k])
                    # extent of the current h5 chunk along the spatial axes
                    si = slice(i*chunks[2], (i+1)*chunks[2])
                    sj = slice(j*chunks[3], (j+1)*chunks[3])
                    sk = slice(k*chunks[4], (k+1)*chunks[4])

                    # skip chunks that contain no head-mask voxels
                    nmask = mask[:, :, si, sj, sk].sum()
                    if nmask == 0:
                        continue
                    # stack this chunk of every training volume
                    r = np.concatenate([x[ii:ii+1, :, si, sj, sk]
                                        for ii in train_ids], axis=0)
                    r = torch.Tensor(r.astype(np.float32))
                    ry = y_z[:, :, si, sj, sk]
                    r -= mean[:, :, si, sj, sk]
                    if not nostd:
                        r /= std[:, :, si, sj, sk]
                    r = r*mask[:, :, si, sj, sk]
                    ry = ry*mask[:, :, si, sj, sk]
                    r = r.view(r.shape[0], -1)
                    ry = ry.view(ry.shape[0], -1)

                    # masked inner products between y_z and each training
                    # volume, accumulated with weight nmask/Nmask
                    localkernelrow = torch.matmul(ry, torch.transpose(r, 0, 1)).view(1,1,1,1,n)/nmask
                    kmeanrow += localkernelrow*nmask/Nmask
                    print(kmeanrow[0,0,0,0,0])
    return kmeanrow
Example #6
def main(args):
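    """Train the VAE (or a plain autoencoder when args.mse is set) on the
    z-scored MRI volumes, resuming from a matching checkpoint if one exists.
    """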
    # setup MRI data loaders

    train_ids, val_ids = create_train_idxs(395)

    if args.voxnorm:
        h5icov = "icov_x_z_anon.h5" if args.anon else "icov_x_z.h5"
        # if not os.path.exists("icov_x_z.h5"):
        #     print("calculating voxel stats")
        #     calc_icov(train_ids)
        #     print("done")
        with h5py.File(h5icov, "r") as hfstat:
            voxmean = torch.Tensor(np.moveaxis(hfstat["mean"][:], 3, 0))
            voxstd = torch.sqrt(
                torch.Tensor(np.moveaxis(hfstat["var"][:], 3, 0)))
        if args.cuda:
            voxmean = voxmean.cuda()
            voxstd = voxstd.cuda()
    if args.anon:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii"
        ).get_fdata()
    else:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii"
        ).get_fdata()
    if args.cuda:
        mask = torch.Tensor(mask).cuda()
    else:
        mask = torch.Tensor(mask)
    mask = mask[:192, 2:226, :192]

    params = {'batch_size': args.batch, 'shuffle': False, 'num_workers': 0}
    # train_loader, test_loader
    h5file = "x_z_anon.h5" if args.anon else "x_z.h5"
    training_set = Dataset3dH5(train_ids, h5file=h5file)
    train_loader = data.DataLoader(training_set, **params)
    testing_set = Dataset3dH5(val_ids, h5file=h5file)
    test_loader = data.DataLoader(testing_set, **params)

    if args.cuda:
        device = 'cuda'
    else:
        device = 'cpu'

    # setup the VAE
    vae = VAE(z_dim=args.num_zdim,
              use_cuda=args.cuda,
              use_resblocks=args.resblocks,
              n_downsample=args.ndownsample,
              maxchannel=args.maxchannels,
              zchannel=16,
              variational=not args.mse)

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weightdecay)
    #optimizer = optim.SGD(vae.parameters(), lr=args.learning_rate, weight_decay=0.01)

    train_elbo = []
    test_elbo = []

    startepoch = 0
    argshash = f"{args.num_zdim}_{args.maxchannels}_{args.ndownsample}{'_mse' if args.mse else '_vae'}{'_l1' if args.l1loss else ''}{'_rb' if args.resblocks else ''}{'voxnorm' if args.voxnorm else ''}{'_maskloss' if args.maskloss else ''}{'_addmin' if args.addmin else ''}"
    checkpointpath = f'checkpoint/vae_{argshash}.pt'
    if os.path.exists(checkpointpath):
        print(f"loading {checkpointpath}")
        state = torch.load(checkpointpath, map_location=torch.device(device))
        vae.load_state_dict(state['model'])
        for g in state["optimizer"]["param_groups"]:
            g['lr'] = args.learning_rate
        optimizer.load_state_dict(state["optimizer"])
        train_elbo = state["train_elbo"]
        test_elbo = state["test_elbo"]
        startepoch = len(train_elbo)
        del state['model']
        del state['optimizer']

    # training loop
    for epoch in range(startepoch, args.num_epochs):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        count = 0
        for (x, _) in tqdm(train_loader):
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()

            if args.voxnorm:
                x = (x - voxmean) / voxstd * mask
            else:
                x = x * mask
                if args.addmin:
                    x = x - x.min()
            optimizer.zero_grad()
            if not args.mse:
                recon, mu, logvar = vae(x)
                if args.maskloss:
                    recon = recon * mask
                loss, mse = loss_function(recon,
                                          x,
                                          mu,
                                          logvar,
                                          mse=not args.l1loss)
            else:
                recon = vae(x)
                if args.maskloss:
                    recon = recon * mask
                loss = mseloss(recon, x)
                mse = loss

            # do ELBO gradient and accumulate loss
            loss.backward()
            optimizer.step()

            #total_norm = np.sum([p.grad.data.norm(2).item() for p in vae.parameters()])
            #print(f"gradient norm: {total_norm}")

            steploss = loss.item()
            count += 1
            epoch_loss += steploss

            if args.verbose:
                print(
                    f"loss for step: {steploss}, mse: {mse}, running mean: {epoch_loss/count}, previous epochs: {np.mean(train_elbo)}"
                )

        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train * args.batch
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))


        if epoch % args.test_frequency == 0 and len(test_loader) > 0:
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            #for i, x in enumerate(test_loader):
            with torch.no_grad():
                for (x, _) in test_loader:
                    # if on GPU put mini-batch into CUDA memory
                    if args.cuda:
                        x = x.cuda()
                    if args.voxnorm:
                        x = (x - voxmean) / voxstd * mask
                    else:
                        x = x * mask
                        if args.addmin:
                            x = x - x.min()
                    # compute ELBO estimate and accumulate loss

                    if not args.mse:
                        mu, logvar = vae.encode(x)
                        recon = vae.decode(mu)
                        if args.maskloss:
                            recon = recon * mask
                        loss, mse = loss_function(recon,
                                                  x,
                                                  mu,
                                                  logvar,
                                                  mse=not args.l1loss)
                    else:
                        recon = vae(x)
                        if args.maskloss:
                            recon = recon * mask
                        loss = mseloss(recon, x)
                    test_loss += loss.item()

            # report test diagnostics
            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test * args.batch
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d]  average test loss: %.4f" %
                  (epoch, total_epoch_loss_test))

        torch.save(
            {
                "args": args,
                "model": vae.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "train_elbo": train_elbo,
                "test_elbo": test_elbo
            }, checkpointpath)
        if test_elbo and min(test_elbo) == test_elbo[-1]:
            # note: appending to the existing name yields
            # 'checkpoint/vae_<argshash>.pt_best.pt', which is the filename
            # the evaluation scripts load
            torch.save(
                {
                    "args": args,
                    "model": vae.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                    "train_elbo": train_elbo,
                    "test_elbo": test_elbo
                }, checkpointpath + '_best.pt')

    return vae
Example #7
                        ic = ic[:, :, :, modalityidx, :]
                        ic = torch.Tensor(ic)
                    mean0 = mean[i * c[0]:maxi, j * c[1]:maxj, k * c[2]:maxk]
                    mean0 = mean0[:, :, :, modalityidx]
                    mean0 = torch.Tensor(mean0).unsqueeze(4)
                    if use_cuda:
                        ic = ic.cuda()
                        out = out.cuda()
                        mean0 = mean0.cuda()
                    xin = x[0, :, i * c[0]:maxi, j * c[1]:maxj,
                            k * c[2]:maxk].permute(1, 2, 3, 0).unsqueeze(4)
                    xin -= mean0
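                    # per-voxel quadratic form (x - mean)^T ic (x - mean),
                    # i.e. the squared Mahalanobis distance across modalities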
                    out[0, 0, i * c[0]:maxi, j * c[1]:maxj,
                        k * c[2]:maxk] = torch.matmul(
                            torch.matmul(ic, xin).permute(0, 1, 2, 4, 3),
                            xin).squeeze()
        return out, xin.shape[3]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-i',
                        '--input',
                        default="x_z.h5",
                        type=str,
                        help='input h5 file')
    args = parser.parse_args()
    print(args)

    train_ids, _ = create_train_idxs(395)
    calc_icov(train_ids, filename=args.input)
Example #8
def main(args):
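    """Score the held-out healthy volumes with four residual maps (z-score
    energy, Mahalanobis distance, projection residual, VAE reconstruction
    error) and append them to test_healthy/r.h5.
    """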
    _, val_ids = create_train_idxs(395)

    with h5py.File(f"icov_x_z.h5", "r") as hfstat:
        voxmean = torch.Tensor(np.moveaxis(hfstat["mean"][:], 3, 0))
        voxstd = torch.sqrt(torch.Tensor(np.moveaxis(hfstat["var"][:], 3, 0)))

    if args.anon:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii"
        ).get_fdata()
    else:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii"
        ).get_fdata()
    mask = torch.Tensor(mask)
    mask = mask[:192, 2:226, :192]
    params = {'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    testing_set = Dataset3dH5(val_ids)
    test_loader = data.DataLoader(testing_set, **params)

    # due to limited memory size compute pca residuals in batches
    #batchsize=8
    #pcaresiduals = pca_project_test(val_ids[:batchsize])
    #offset=0

    vae = VAE(z_dim=512,
              use_cuda=True,
              use_resblocks=True,
              n_downsample=3,
              maxchannel=128,
              zchannel=16,
              variational=False)
    checkpointpath = 'checkpoint/vae_512_128_3_mse_rbvoxnorm.pt_best.pt'
    s = torch.load(checkpointpath, map_location=torch.device("cuda"))
    vae.load_state_dict(s["model"])
    del s
    j = 0
    with h5py.File(f"test_healthy/r.h5", "w") as hf:
        for (x, _) in tqdm(test_loader):
            xnorm = (x - voxmean) / voxstd * mask

            rsimple = (xnorm**2).sum(dim=1, keepdim=True)
            h5icov = "icov_x_z_anon.h5" if args.anon else "icov_x_z.h5"
            rmvn, _ = mahalanobis2(deepcopy(x),
                                   filename=h5icov,
                                   use_cuda=False)

            # if j-offset+1 > pcaresiduals.shape[0]:
            #     # next pca batch
            #     del pcaresiduals
            #     offset = j
            #     pcaresiduals = pca_project_test(val_ids[offset:min(offset+batchsize, len(val_ids))], cuda=False)
            # rpca = (pcaresiduals[j-offset:j-offset+1]**2).sum(dim=1, keepdim=True)

            _, rpca = direct_project(x, "z", anon=args.anon)
            rpca = (rpca**2).sum(dim=1, keepdim=True)

            recon = vae(xnorm.cuda()).cpu()
            rvae = ((xnorm - recon)**2).sum(dim=1, keepdim=True)

            r = torch.cat([
                rsimple,
                rmvn,
                rpca,
                rvae,
            ], dim=1)

            if j == 0:
                print(f"creating dataset with size {r.shape}")
                hf.create_dataset('r',
                                  data=r.cpu().numpy().astype(np.float32),
                                  dtype=np.float32,
                                  chunks=(1, r.shape[1], 16, 16, 16),
                                  maxshape=(len(val_ids), 4, 192, 224, 192),
                                  **hdf5plugin.Blosc())
            else:
                hf["r"].resize((hf["r"].shape[0] + 1), axis=0)
                hf["r"][-1:] = r.cpu().numpy().astype(np.float32)
            j += 1
            del r, rpca
            del rvae
            del rsimple, rmvn, xnorm, x
Example #9
def main(args):
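    """Score the pathology sample set with the same four residual maps as the
    healthy-set script and append them to test_patho/r.h5.
    """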
    _, val_ids = create_train_idxs(395)

    with h5py.File(f"icov_x_z.h5", "r") as hfstat:
        voxmean = torch.Tensor(np.moveaxis(hfstat["mean"][:],3,0))
        voxstd = torch.sqrt(torch.Tensor(np.moveaxis(hfstat["var"][:],3,0)))

    if args.anon:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii").get_fdata()
    else:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii").get_fdata()
    mask = torch.Tensor(mask)
    maskc = mask[:192, 2:226, :192]
    
    params = {'batch_size': 1,
        'shuffle': False,
        'num_workers': 0}
    if args.anon:
        dfpatho = pd.read_csv("samples_anon_patho.csv")
        dset1 = Dataset3dNonlinearAnonymized(dfpatho)
    else:
        dfpatho = pd.read_csv("patho_sample44_t1_t1ce_flairfs_t2_t2star_adc_tracew_mprage.csv")
        dset1 = Dataset3dNonlinear(dfpatho, suffix="_t2")
    test_loader = data.DataLoader(dset1, **params)

    vae = VAE(z_dim=512,
              use_cuda=True,
              use_resblocks=True,
              n_downsample=3,
              maxchannel=128,
              zchannel=16,
              variational=False)
    checkpointpath = 'checkpoint/vae_512_128_3_mse_rbvoxnorm.pt_best.pt'
    s = torch.load(checkpointpath, map_location=torch.device("cuda"))
    vae.load_state_dict(s["model"])
    del s
    j = 0
    with h5py.File("test_patho/r.h5", "w") as hf:
        for (x, _) in tqdm(test_loader):
            # z-transform using the full-size mask, then crop to the training
            # grid before voxelwise normalization
            x = ztransform(x, mask)
            x = x[:, :, :192, 2:226, :192]
            xnorm = (x - voxmean) / voxstd * maskc

            rsimple = (xnorm**2).sum(dim=1, keepdim=True)
            h5icov = "icov_x_z_anon.h5" if args.anon else "icov_x_z.h5"
            rmvn, _ = mahalanobis2(deepcopy(x), filename=h5icov, use_cuda=False)

            _, rpca = direct_project(x, "z", anon=args.anon)
            rpca = (rpca**2).sum(dim=1, keepdim=True)

            recon = vae(xnorm.cuda()).cpu()
            rvae = ((xnorm - recon)**2).sum(dim=1, keepdim=True)

            r = torch.cat([
                rsimple,
                rmvn,
                rpca,
                rvae,
            ], dim=1)

            if j == 0:
                print(f"creating dataset with size {r.shape}")
                # note: maxshape is bounded by len(val_ids); this assumes the
                # pathology set is no larger than the validation set
                hf.create_dataset('r',
                                  data=r.cpu().numpy().astype(np.float32),
                                  dtype=np.float32,
                                  chunks=(1, r.shape[1], 16, 16, 16),
                                  maxshape=(len(val_ids), r.shape[1], 192, 224, 192),
                                  **hdf5plugin.Blosc())
            else:
                hf["r"].resize((hf["r"].shape[0] + 1), axis=0)
                hf["r"][-1:] = r.cpu().numpy().astype(np.float32)
            j += 1
            del r, rpca
            del rvae
            del rsimple, rmvn, xnorm, x