def direct_project(y, normalizationmode, nostd=False, anon=False):
    """Greedily project `y` onto each z-scored training volume in turn.

    For every training sample, the (head-mask-restricted) correlation with the
    current residual is computed and that component is subtracted, leaving a
    residual that is approximately orthogonal to the training set.

    Parameters:
        y: input volume, broadcastable against the (1, C, 192, 224, 192) stats.
        normalizationmode: suffix selecting "icov_x_<mode>.h5" / "x_<mode>.h5".
        nostd: if True, only mean-center (skip division by voxel std).
        anon: if True, use the defaced head mask.

    Returns:
        (projection, residual) where projection + residual == z-scored y.
    """
    train_ids, _ = create_train_idxs(395)
    if anon:
        mask_path = "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii"
    else:
        mask_path = "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii"
    # Crop to the working FOV and add batch/channel axes.
    mask = torch.Tensor(nb.load(mask_path).get_fdata()[:192, 2:226, :192][None, None, :, :, :])
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:], 3, 0)[None, :, :, :, :].astype(np.float32)
        std = np.sqrt(np.moveaxis(hfout["var"][:], 3, 0))[None, :, :, :, :].astype(np.float32)
    denom = 1.0 if nostd else std
    y_z = (y - mean) / denom
    residual = deepcopy(y_z)
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        for ii, i in enumerate(train_ids):
            print(f"{ii},{i}")
            # z-score the training volume the same way as y.
            basis = torch.Tensor((x[i].astype(np.float32) - mean) / denom)
            basis_norm = torch.norm(basis * mask)
            # Masked correlation of the residual with this basis vector.
            coef = (residual * basis * mask).sum() / basis_norm
            print(coef)
            residual = residual - coef * basis / basis_norm
    return y_z - residual, residual
def compute_kernelrow(y, normalizationmode, nostd=False):
    """Compute one row of the masked mean kernel between `y` and each
    training volume.

    Parameters:
        y: input volume, broadcastable against the (1, C, 192, 224, 192) stats.
        normalizationmode: suffix selecting "icov_x_<mode>.h5" / "x_<mode>.h5".
        nostd: if True, only mean-center (skip division by voxel std).

    Returns:
        torch.Tensor of shape (1, 1, 1, 1, len(train_ids)); entry at position
        ii is sum(y_z * x_z * mask) / mask.sum() for train_ids[ii], i.e. the
        row is ordered positionally like `train_ids`, matching
        compute_kernelrow_local.
    """
    mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii").get_fdata()
    mask = mask[:192, 2:226, :192][None, None, :, :, :]
    mask = torch.Tensor(mask)
    Nmask = mask.sum()
    train_ids, _ = create_train_idxs(395)
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:], 3, 0)[None, :, :, :, :]
        std = np.sqrt(np.moveaxis(hfout["var"][:], 3, 0))[None, :, :, :, :]
    yz = (y - mean)
    if not nostd:
        yz /= std
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        kmeanrow = torch.zeros(1, 1, 1, 1, len(train_ids))
        # BUGFIX: index the output by the *position* ii, not the raw dataset
        # id i.  The row only has len(train_ids) slots, so a raw id can be out
        # of range (IndexError), and consumers such as pca_project address the
        # row positionally (kernelrow[..., ii]), exactly as
        # compute_kernelrow_local builds it.
        for ii, i in enumerate(train_ids):
            print(i)
            xi = torch.Tensor((x[i].astype(np.float32) - mean) / (1.0 if nostd else std))
            kmeanrow[0, 0, 0, 0, ii] = torch.sum((yz * xi) * mask / Nmask)
            print(kmeanrow[0, 0, 0, 0, ii])
    return kmeanrow
def pca_project(y, normalizationmode, nostd=False):
    """Project `y` onto the span of the training volumes via regularized
    kernel regression.

    Solves for coefficients c = k_y (K + eps*I)^-1 using the precomputed
    training-set kernel, then reconstructs the projection as the
    coefficient-weighted sum of the z-scored training volumes.

    Returns:
        (projection, residual) where residual = z-scored y minus projection.
    """
    train_ids, _ = create_train_idxs(395)
    # Row of kernel values between y and every training sample.
    kernelrow = compute_kernelrow_local(y, normalizationmode, nostd)
    print(kernelrow)
    kernel_file = f"localkernel_{normalizationmode}{'_nostd' if nostd else ''}.h5"
    with h5py.File(kernel_file, "r") as hfout:
        kernel = torch.Tensor(hfout['kernel'][0, 0, 0])
    kernel = kernel[train_ids, :][:, train_ids]
    # Ridge regularization scaled to the kernel's diagonal magnitude.
    eps = kernel.diagonal().mean() / 10
    regularized = kernel + eps * torch.eye(kernel.shape[-1])
    coefrow = torch.matmul(kernelrow[0, 0, 0], torch.inverse(regularized))
    projection = 0 * y
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:], 3, 0)[None, :, :, :, :]
        std = np.sqrt(np.moveaxis(hfout["var"][:], 3, 0))[None, :, :, :, :]
    scale = 1.0 if nostd else std
    y_z = (y - mean) / scale
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        for pos, sample_id in enumerate(train_ids):
            print(coefrow[0, pos])
            print(pos)
            # Accumulate the coefficient-weighted z-scored training volume.
            projection += coefrow[0, pos] * (x[sample_id].astype(np.float32) - mean) / scale
    return projection, y_z - projection
def pca_project_test(test_ids, cuda=False):
    """Compute kernel-regression residuals for a batch of held-out volumes.

    Uses the precomputed training kernel ("localkernel_z.h5") to solve for the
    projection coefficients of each test volume, then subtracts the projection
    from the z-scored test data in a single pass over the training volumes.

    Parameters:
        test_ids: indices into the h5 dataset for the test volumes.
        cuda: if True, move tensors to the GPU.

    Returns:
        torch.Tensor of shape (len(test_ids), 9, 192, 224, 192) of residuals.
    """
    train_ids, _ = create_train_idxs(395)
    with h5py.File(f"localkernel_z.h5", "r") as hfout:
        kernel = torch.Tensor(hfout['kernel'][0, 0, 0])
    # BUGFIX throughout: Tensor.cuda() is NOT in-place -- it returns a new
    # tensor.  The original discarded every result, so cuda=True silently ran
    # on the CPU.  Each result is now assigned back.
    if cuda:
        kernel = kernel.cuda()
    kernelrows = kernel[test_ids, :][:, train_ids]
    kernel = kernel[train_ids, :][:, train_ids]
    # Ridge regularization scaled to the kernel's diagonal magnitude.
    eps = kernel.diagonal().mean() / 10
    epseye = torch.eye(kernel.shape[-1])
    if cuda:
        epseye = epseye.cuda()
    coefrows = torch.matmul(kernelrows, torch.inverse(kernel + eps * epseye))
    print(coefrows)
    residuals = torch.zeros(len(test_ids), 9, 192, 224, 192)
    if cuda:
        residuals = residuals.cuda()
    with h5py.File(f"icov_x_z.h5", "r") as hfout:
        mean = torch.Tensor(np.moveaxis(hfout["mean"][:], 3, 0)[None, :, :, :, :])
        std = torch.Tensor(np.sqrt(np.moveaxis(hfout["var"][:], 3, 0))[None, :, :, :, :])
    if cuda:
        mean = mean.cuda()
        std = std.cuda()
    print("loading h5 file")
    with h5py.File(f"x_z.h5", "r") as hf:
        x = hf['x']
        print("loading test data")
        # Initialize residuals with the z-scored test volumes.
        for ii, i in enumerate(test_ids):
            print(ii)
            xi = torch.Tensor(x[i].astype(np.float32))
            if cuda:
                xi = xi.cuda()
            residuals[ii] = (xi - mean) / std
        print("updating residuals")
        # Subtract each training volume's contribution from all residuals.
        for ii, i in enumerate(train_ids):
            xi = torch.Tensor(x[i].astype(np.float32))
            if cuda:
                xi = xi.cuda()
            print(ii)
            residuals -= coefrows[:, ii].view(-1, 1, 1, 1, 1) * (xi - mean) / std
    return residuals
def compute_kernelrow_local(y, normalizationmode, nostd=False, anon=False):
    """Compute the masked mean-kernel row between `y` and every training
    volume, streaming over the h5 dataset chunk by chunk to limit memory.

    Per spatial chunk, the z-scored training volumes are flattened and
    multiplied against the z-scored chunk of `y`; per-chunk kernels are
    accumulated weighted by the chunk's mask fraction, so the total equals
    the full masked inner product divided by the mask size.

    Parameters:
        y: input volume, broadcastable against the (1, C, 192, 224, 192)
           stats.  # assumes a numpy array — TODO confirm against callers
        normalizationmode: suffix selecting "icov_x_<mode>.h5" / "x_<mode>.h5".
        nostd: if True, only mean-center (skip division by voxel std).
        anon: if True, use the defaced head mask.

    Returns:
        torch.Tensor of shape (1, 1, 1, 1, len(train_ids)), ordered like
        `train_ids`.
    """
    if anon:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii").get_fdata()
    else:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii").get_fdata()
    # Crop to the working FOV and add batch/channel axes.
    mask = mask[:192,2:226,:192][None,None,:,:,:]
    nmask = mask.sum()  # NOTE: reassigned per-chunk inside the loop below
    mask = torch.Tensor(mask)
    train_ids, _ = create_train_idxs(395)
    #train_ids = train_ids[:10]
    with h5py.File(f"icov_x_{normalizationmode}.h5", "r") as hfout:
        mean = np.moveaxis(hfout["mean"][:],3,0)[None,:,:,:,:]
        std = np.sqrt(np.moveaxis(hfout["var"][:],3,0))[None,:,:,:,:]
    y_z = (y-mean)/(1.0 if nostd else std)
    with h5py.File(f"x_{normalizationmode}.h5", "r") as hf:
        x = hf['x']
        n = len(train_ids)
        chunks = x.chunks
        kmeanrow = torch.zeros(1,1,1,1,n)
        Nmask = mask.sum()  # total mask voxels, for chunk weighting
        # Iterate over the h5 dataset's spatial chunk grid (axes 2..4).
        for i in range(x.shape[2]//x.chunks[2]):
            for j in range(x.shape[3]//x.chunks[3]):
                for k in range(x.shape[4]//x.chunks[4]):
                    print([i,j,k])
                    maxi = (i+1)*chunks[2]
                    maxj = (j+1)*chunks[3]
                    maxk = (k+1)*chunks[4]
                    # Skip chunks entirely outside the head mask.
                    nmask = mask[:,:, i*x.chunks[2]:maxi, j*x.chunks[3]:maxj, k*x.chunks[4]:maxk ].sum()
                    if nmask == 0:
                        continue
                    # Gather this chunk for every training volume: (n, C, cx, cy, cz).
                    r = np.concatenate([x[ii:ii+1,:, i*x.chunks[2]:maxi, j*x.chunks[3]:maxj, k*x.chunks[4]:maxk ] for ii in train_ids], axis=0)
                    r = r.astype(np.float32)
                    ry = y_z[:,:, i*x.chunks[2]:maxi, j*x.chunks[3]:maxj, k*x.chunks[4]:maxk ]
                    r = torch.Tensor(r)
                    #mean = r.mean(dim=0, keepdim=True)
                    # z-score the training chunks with the global voxel stats.
                    r -= mean[:,:, i*x.chunks[2]:maxi, j*x.chunks[3]:maxj, k*x.chunks[4]:maxk ]
                    #ry -= mean
                    #std = r.std(dim=0, keepdim=True)
                    if not nostd:
                        r /= std[:,:, i*x.chunks[2]:maxi, j*x.chunks[3]:maxj, k*x.chunks[4]:maxk ]
                        #ry /= std
                    r = r*mask[:,:, i*x.chunks[2]:maxi, j*x.chunks[3]:maxj, k*x.chunks[4]:maxk ]
                    # NOTE(review): if y_z is numpy, ndarray * Tensor defers to
                    # torch, so ry becomes a Tensor here — confirm caller dtype.
                    ry = ry*mask[:,:, i*x.chunks[2]:maxi, j*x.chunks[3]:maxj, k*x.chunks[4]:maxk ]
                    r = r.view(r.shape[0],-1)
                    ry = ry.view(ry.shape[0],-1)
                    # Per-chunk kernel row; the /nmask then *nmask/Nmask
                    # weighting sums to the global masked mean kernel.
                    localkernelrow = torch.matmul(ry, torch.transpose(r,0,1)).view(1,1,1,1,n)/nmask
                    kmeanrow += localkernelrow*nmask/Nmask
    print(kmeanrow[0,0,0,0,0])
    return kmeanrow
def main(args):
    """Train (or resume training) a 3D VAE/autoencoder on MRI volumes.

    Loads optional per-voxel normalization stats, builds masked data loaders,
    constructs the VAE from `args`, resumes from a checkpoint named after the
    argument hash if one exists, then runs the train/test loop, saving the
    latest and best checkpoints.

    Returns:
        The trained VAE model.
    """
    # setup MRI data loaders
    train_ids, val_ids = create_train_idxs(395)
    if args.voxnorm:
        # Per-voxel mean/std for z-scoring inputs.
        h5icov = "icov_x_z_anon.h5" if args.anon else "icov_x_z.h5"
        # if not os.path.exists("icov_x_z.h5"):
        #     print("calculating voxel stats")
        #     calc_icov(train_ids)
        #     print("done")
        with h5py.File(h5icov, "r") as hfstat:
            voxmean = torch.Tensor(np.moveaxis(hfstat["mean"][:], 3, 0))
            voxstd = torch.sqrt(
                torch.Tensor(np.moveaxis(hfstat["var"][:], 3, 0)))
        if args.cuda:
            voxmean = voxmean.cuda()
            voxstd = voxstd.cuda()
    # Head mask (defaced variant for anonymized data).
    if args.anon:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii"
        ).get_fdata()
    else:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii"
        ).get_fdata()
    if args.cuda:
        mask = torch.Tensor(mask).cuda()
    else:
        mask = torch.Tensor(mask)
    # Crop to the working FOV.
    mask = mask[:192, 2:226, :192]
    params = {'batch_size': args.batch, 'shuffle': False, 'num_workers': 0}
    # train_loader, test_loader
    h5file = "x_z_anon.h5" if args.anon else "x_z.h5"
    training_set = Dataset3dH5(train_ids, h5file=h5file)
    train_loader = data.DataLoader(training_set, **params)
    testing_set = Dataset3dH5(val_ids, h5file=h5file)
    test_loader = data.DataLoader(testing_set, **params)
    if args.cuda:
        device = 'cuda'
    else:
        device = 'cpu'
    # setup the VAE (plain autoencoder when args.mse is set)
    vae = VAE(z_dim=args.num_zdim,
              use_cuda=args.cuda,
              use_resblocks=args.resblocks,
              n_downsample=args.ndownsample,
              maxchannel=args.maxchannels,
              zchannel=16,
              variational=not args.mse)
    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weightdecay)
    #optimizer = optim.SGD(vae.parameters(), lr=args.learning_rate, weight_decay=0.01)
    train_elbo = []
    test_elbo = []
    startepoch = 0
    # Hash of the hyperparameters: used to name checkpoints so runs resume.
    argshash = f"{args.num_zdim}_{args.maxchannels}_{args.ndownsample}{'_mse' if args.mse else '_vae'}{'_l1' if args.l1loss else ''}{'_rb' if args.resblocks else ''}{'voxnorm' if args.voxnorm else ''}{'_maskloss' if args.maskloss else ''}{'_addmin' if args.addmin else ''}"
    checkpointpath = f'checkpoint/vae_{argshash}.pt'
    if os.path.exists(checkpointpath):
        # Resume: restore weights, optimizer state (with the new lr), and
        # loss history; continue from the recorded epoch.
        print(f"loading {checkpointpath}")
        state = torch.load(checkpointpath, map_location=torch.device(device))
        vae.load_state_dict(state['model'])
        for g in state["optimizer"]["param_groups"]:
            g['lr'] = args.learning_rate
        optimizer.load_state_dict(state["optimizer"])
        train_elbo = state["train_elbo"]
        test_elbo = state["test_elbo"]
        startepoch = len(train_elbo)
        del state['model']
        del state['optimizer']
    # training loop
    for epoch in range(startepoch, args.num_epochs):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        count = 0
        for (x, _) in tqdm(train_loader):
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()
            if args.voxnorm:
                x = (x - voxmean) / voxstd * mask
            else:
                x = x * mask
            if args.addmin:
                # shift so the minimum is zero
                x = x - x.min()
            optimizer.zero_grad()
            if not args.mse:
                recon, mu, logvar = vae(x)
                if args.maskloss:
                    recon = recon * mask
                loss, mse = loss_function(recon,
                                          x,
                                          mu,
                                          logvar,
                                          mse=not args.l1loss)
            else:
                recon = vae(x)
                if args.maskloss:
                    recon = recon * mask
                loss = mseloss(recon, x)
                mse = loss
            # do ELBO gradient and accumulate loss
            loss.backward()
            optimizer.step()
            #total_norm = np.sum([p.grad.data.norm(2).item() for p in vae.parameters()])
            #print(f"gradient norm: {total_norm}")
            steploss = loss.item()
            count += 1
            epoch_loss += steploss
            if args.verbose:
                print(
                    f"loss for step: {steploss}, mse: {mse}, running mean: {epoch_loss/count}, previous epochs: {np.mean(train_elbo)}"
                )
        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train * args.batch
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d] average training loss: %.4f" %
              (epoch, total_epoch_loss_train))
        checkpointpath = f'checkpoint/vae_{argshash}.pt'
        if epoch % args.test_frequency == 0 and len(test_loader) > 0:
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            #for i, x in enumerate(test_loader):
            with torch.no_grad():
                for (x, _) in test_loader:
                    # if on GPU put mini-batch into CUDA memory
                    if args.cuda:
                        x = x.cuda()
                    if args.voxnorm:
                        x = (x - voxmean) / voxstd * mask
                    else:
                        x = x * mask
                    if args.addmin:
                        x = x - x.min()
                    # compute ELBO estimate and accumulate loss
                    if not args.mse:
                        mu, logvar = vae.encode(x)
                        recon = vae.decode(mu)
                        if args.maskloss:
                            recon = recon * mask
                        loss, mse = loss_function(recon,
                                                  x,
                                                  mu,
                                                  logvar,
                                                  mse=not args.l1loss)
                    else:
                        recon = vae(x)
                        if args.maskloss:
                            recon = recon * mask
                        loss = mseloss(recon, x)
                    test_loss += loss.item()
            # report test diagnostics
            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test * args.batch
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d] average test loss: %.4f" %
                  (epoch, total_epoch_loss_test))
            # save the latest checkpoint
            torch.save(
                {
                    "args": args,
                    "model": vae.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                    "train_elbo": train_elbo,
                    "test_elbo": test_elbo
                }, checkpointpath)
            # additionally save a "best" checkpoint when the test loss is the
            # lowest seen so far
            if min(test_elbo) == test_elbo[-1]:
                torch.save(
                    {
                        "args": args,
                        "model": vae.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "train_elbo": train_elbo,
                        "test_elbo": test_elbo
                    }, checkpointpath + '_best.pt')
    return vae
    # NOTE(review): the lines below are the tail of a function whose `def`
    # line (and the loop headers binding i, j, k, c, maxi/maxj/maxk, ic,
    # mean, x, out, use_cuda, modalityidx) lies outside this view; the
    # indentation here is a best guess — confirm against the full file.
    # It appears to accumulate a per-chunk quadratic form x^T * ic * x into
    # `out` — presumably a Mahalanobis-distance computation; verify.
    ic = ic[:, :, :, modalityidx, :]
    ic = torch.Tensor(ic)
    mean0 = mean[i * c[0]:maxi, j * c[1]:maxj, k * c[2]:maxk]
    mean0 = mean0[:, :, :, modalityidx]
    mean0 = torch.Tensor(mean0).unsqueeze(4)
    if use_cuda:
        ic = ic.cuda()
        out = out.cuda()
        mean0 = mean0.cuda()
    xin = x[0, :, i * c[0]:maxi, j * c[1]:maxj,
            k * c[2]:maxk].permute(1, 2, 3, 0).unsqueeze(4)
    xin -= mean0
    # Quadratic form per voxel: (x-mean)^T * ic * (x-mean).
    out[0, 0, i * c[0]:maxi, j * c[1]:maxj, k * c[2]:maxk] = torch.matmul(
        torch.matmul(ic, xin).permute(0, 1, 2, 4, 3), xin).squeeze()
    return out, xin.shape[3]


if __name__ == '__main__':
    # Script entry point: compute inverse-covariance voxel stats for the
    # training split of the given h5 file.
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-i',
                        '--input',
                        default="x_z.h5",
                        type=str,
                        help='input h5 file')
    args = parser.parse_args()
    print(args)
    train_ids, _ = create_train_idxs(395)
    calc_icov(train_ids, filename=args.input)
def main(args):
    """Evaluate four anomaly scores on the healthy validation split.

    For each validation volume, computes per-voxel residual-energy maps for:
    simple z-score energy, Mahalanobis distance, direct-projection residual,
    and VAE reconstruction error; the four maps are concatenated on the
    channel axis and appended to "test_healthy/r.h5".
    """
    _, val_ids = create_train_idxs(395)
    with h5py.File(f"icov_x_z.h5", "r") as hfstat:
        voxmean = torch.Tensor(np.moveaxis(hfstat["mean"][:], 3, 0))
        voxstd = torch.sqrt(torch.Tensor(np.moveaxis(hfstat["var"][:], 3, 0)))
    if args.anon:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii"
        ).get_fdata()
    else:
        mask = nb.load(
            "mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii"
        ).get_fdata()
    mask = torch.Tensor(mask)
    # Crop to the working FOV.
    mask = mask[:192, 2:226, :192]
    params = {'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    testing_set = Dataset3dH5(val_ids)
    test_loader = data.DataLoader(testing_set, **params)
    # due to limited memory size compute pca residuals in batches
    #batchsize=8
    #pcaresiduals = pca_project_test(val_ids[:batchsize])
    #offset=0
    vae = VAE(z_dim=512,
              use_cuda=True,
              use_resblocks=True,
              n_downsample=3,
              maxchannel=128,
              zchannel=16,
              variational=False)
    checkpointpath = f'checkpoint/vae_512_128_3_mse_rbvoxnorm.pt_best.pt'
    s = torch.load(checkpointpath, map_location=torch.device("cuda"))
    vae.load_state_dict(s["model"])
    del s
    # FIX: h5icov is loop-invariant; hoisted out of the per-volume loop.
    h5icov = "icov_x_z_anon.h5" if args.anon else "icov_x_z.h5"
    j = 0
    with h5py.File(f"test_healthy/r.h5", "w") as hf:
        for (x, _) in tqdm(test_loader):
            # (removed dead no-op `x = x`)
            xnorm = (x - voxmean) / voxstd * mask
            # score 1: plain z-score energy
            rsimple = (xnorm**2).sum(dim=1, keepdim=True)
            # score 2: Mahalanobis distance (deepcopy: mahalanobis2 mutates x)
            rmvn, _ = mahalanobis2(deepcopy(x), filename=h5icov, use_cuda=False)
            # if j-offset+1 > pcaresiduals.shape[0]:
            #     # next pca batch
            #     del pcaresiduals
            #     offset = j
            #     pcaresiduals = pca_project_test(val_ids[offset:min(offset+batchsize, len(val_ids))], cuda=False)
            # rpca = (pcaresiduals[j-offset:j-offset+1]**2).sum(dim=1, keepdim=True)
            # score 3: residual after direct projection onto training volumes
            _, rpca = direct_project(x, "z", anon=args.anon)
            rpca = (rpca**2).sum(dim=1, keepdim=True)
            # score 4: VAE reconstruction error
            recon = vae(xnorm.cuda()).cpu()
            rvae = ((xnorm - recon)**2).sum(dim=1, keepdim=True)
            r = torch.cat([
                rsimple,
                rmvn,
                rpca,
                rvae,
            ], dim=1)
            if j == 0:
                print(f"creating dataset with size {r.shape}")
                # FIX: maxshape channel count derived from r.shape[1] instead
                # of a hard-coded 4, consistent with the patho variant.
                hf.create_dataset('r',
                                  data=r.cpu().numpy().astype(np.float32),
                                  dtype=np.float32,
                                  chunks=(1, r.shape[1], 16, 16, 16),
                                  maxshape=(len(val_ids), r.shape[1], 192, 224, 192),
                                  **hdf5plugin.Blosc())
            else:
                hf["r"].resize((hf["r"].shape[0] + 1), axis=0)
                hf["r"][-1:] = r.cpu().numpy().astype(np.float32)
            j += 1
            # free the large per-volume tensors before the next iteration
            del r, rpca
            del rvae
            del rsimple, rmvn, xnorm, x
def main(args):
    """Evaluate four anomaly scores on the pathology test set.

    Loads pathological cases from a CSV-driven dataset, z-transforms and crops
    each volume, then computes per-voxel residual-energy maps for: simple
    z-score energy, Mahalanobis distance, direct-projection residual, and VAE
    reconstruction error; the four maps are concatenated on the channel axis
    and appended to "test_patho/r.h5".
    """
    _, val_ids = create_train_idxs(395)
    with h5py.File(f"icov_x_z.h5", "r") as hfstat:
        voxmean = torch.Tensor(np.moveaxis(hfstat["mean"][:],3,0))
        voxstd = torch.sqrt(torch.Tensor(np.moveaxis(hfstat["var"][:],3,0)))
    if args.anon:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask_defaced.nii").get_fdata()
    else:
        mask = nb.load("mni_icbm152_nlin_asym_09c/mni_icbm152_t1_tal_nlin_asym_09c_headmask.nii").get_fdata()
    mask = torch.Tensor(mask)
    # maskc: mask cropped to the working FOV (full mask kept for ztransform).
    maskc = mask[:192,2:226,:192]
    params = {'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    if args.anon:
        dfpatho = pd.read_csv("samples_anon_patho.csv")
        dset1 = Dataset3dNonlinearAnonymized(dfpatho)
    else:
        dfpatho = pd.read_csv("patho_sample44_t1_t1ce_flairfs_t2_t2star_adc_tracew_mprage.csv")
        dset1 = Dataset3dNonlinear(dfpatho, suffix="_t2")
    test_loader = data.DataLoader(dset1, **params)
    vae = VAE(z_dim=512,
              use_cuda=True,
              use_resblocks=True,
              n_downsample=3,
              maxchannel=128,
              zchannel=16,
              variational=False)
    checkpointpath = f'checkpoint/vae_512_128_3_mse_rbvoxnorm.pt_best.pt'
    s = torch.load(checkpointpath, map_location=torch.device("cuda"))
    vae.load_state_dict(s["model"])
    del s
    j=0
    with h5py.File(f"test_patho/r.h5", "w") as hf:
        for (x, _) in tqdm(test_loader):
            # z-transform with the full (uncropped) mask, then crop to FOV.
            x = ztransform(x, mask)
            x = x[:,:,:192,2:226,:192]
            xnorm = (x-voxmean)/voxstd*maskc
            # score 1: plain z-score energy
            rsimple = (xnorm**2).sum(dim=1, keepdim=True)
            h5icov = "icov_x_z_anon.h5" if args.anon else "icov_x_z.h5"
            # score 2: Mahalanobis distance (deepcopy: x reused below)
            rmvn, _ = mahalanobis2(deepcopy(x), filename=h5icov, use_cuda=False)
            # score 3: residual after direct projection onto training volumes
            _, rpca = direct_project(x, "z", anon=args.anon)
            rpca = (rpca**2).sum(dim=1, keepdim=True)
            # score 4: VAE reconstruction error
            recon = vae(xnorm.cuda()).cpu()
            rvae = ((xnorm-recon)**2).sum(dim=1, keepdim=True)
            r = torch.cat([
                rsimple,
                rmvn,
                rpca,
                rvae,
            ], dim=1)
            if j == 0:
                print(f"creating dataset with size {r.shape}")
                hf.create_dataset('r',
                                  data = r.cpu().numpy().astype(np.float32),
                                  dtype=np.float32,
                                  chunks=(1,r.shape[1],16,16,16),
                                  maxshape=(len(val_ids), r.shape[1], 192, 224, 192),
                                  **hdf5plugin.Blosc())
            else:
                hf["r"].resize((hf["r"].shape[0] + 1), axis = 0)
                hf["r"][-1:] = r.cpu().numpy().astype(np.float32)
            j += 1
            # free the large per-volume tensors before the next iteration
            del r, rpca
            del rvae
            del rsimple, rmvn,xnorm,x