def main(args):
    # set the number of threads
    num_threads = args.num_threads
    from topaz.torch import set_num_threads
    set_num_threads(num_threads)

    ## set the device
    use_cuda = topaz.cuda.set_device(args.device)
    print('# using device={} with cuda={}'.format(args.device, use_cuda), file=sys.stderr)

    cutoff = args.pixel_cutoff # pixel truncation limit

    do_train = (args.dir_a is not None and args.dir_b is not None) or (args.hdf is not None)
    if do_train:
        method = args.method
        paired = (method == 'noise2noise')
        preload = args.preload
        holdout = args.holdout # fraction of image pairs to hold out for validation

        if args.hdf is None: # use dirA/dirB
            crop = args.crop
            dir_as = args.dir_a
            dir_bs = args.dir_b

            dset_train = []
            dset_val = []

            for dir_a, dir_b in zip(dir_as, dir_bs):
                random = np.random.RandomState(44444)
                if paired:
                    dataset_train, dataset_val = make_paired_images_datasets(dir_a, dir_b, crop,
                                                                             random=random,
                                                                             holdout=holdout,
                                                                             preload=preload,
                                                                             cutoff=cutoff)
                else:
                    dataset_train, dataset_val = make_images_datasets(dir_a, dir_b, crop,
                                                                      cutoff=cutoff,
                                                                      random=random,
                                                                      holdout=holdout)
                dset_train.append(dataset_train)
                dset_val.append(dataset_val)

            # merge the datasets from each directory pair by concatenating their image lists
            dataset_train = dset_train[0]
            for i in range(1, len(dset_train)):
                dataset_train.x += dset_train[i].x
                if paired:
                    dataset_train.y += dset_train[i].y

            dataset_val = dset_val[0]
            for i in range(1, len(dset_val)):
                dataset_val.x += dset_val[i].x
                if paired:
                    dataset_val.y += dset_val[i].y

            shuffle = True
        else: # make HDF datasets
            dataset_train, dataset_val = make_hdf5_datasets(args.hdf, paired=paired,
                                                            cutoff=cutoff, holdout=holdout,
                                                            preload=preload)
            shuffle = preload

        # initialize the model
        arch = args.arch
        if arch == 'unet':
            model = dn.UDenoiseNet()
        elif arch == 'unet-small':
            model = dn.UDenoiseNetSmall()
        elif arch == 'unet2':
            model = dn.UDenoiseNet2()
        elif arch == 'unet3':
            model = dn.UDenoiseNet3()
        elif arch == 'fcnet':
            model = dn.DenoiseNet(32)
        elif arch == 'fcnet2':
            model = dn.DenoiseNet2(64)
        elif arch == 'affine':
            model = dn.AffineDenoise()
        else:
            raise Exception('Unknown architecture: ' + arch)

        if use_cuda:
            model = model.cuda()

        # train
        optim = args.optim
        lr = args.lr
        batch_size = args.batch_size
        num_epochs = args.num_epochs
        digits = int(np.ceil(np.log10(num_epochs))) # zero-pad epoch numbers in checkpoint names

        num_workers = args.num_workers

        print('epoch', 'loss_train', 'loss_val')
        #criteria = nn.L1Loss()
        criteria = args.criteria

        if method == 'noise2noise':
            iterator = dn.train_noise2noise(model, dataset_train, lr=lr,
                                            optim=optim,
                                            batch_size=batch_size,
                                            criteria=criteria,
                                            num_epochs=num_epochs,
                                            dataset_val=dataset_val,
                                            use_cuda=use_cuda,
                                            num_workers=num_workers,
                                            shuffle=shuffle)
        elif method == 'masked':
            iterator = dn.train_mask_denoise(model, dataset_train, lr=lr,
                                             optim=optim,
                                             batch_size=batch_size,
                                             criteria=criteria,
                                             num_epochs=num_epochs,
                                             dataset_val=dataset_val,
                                             use_cuda=use_cuda,
                                             num_workers=num_workers,
                                             shuffle=shuffle)
        else: # guard against an undefined iterator below
            raise ValueError('Unknown training method: ' + method)

        for epoch, loss_train, loss_val in iterator:
            print(epoch, loss_train, loss_val)
            sys.stdout.flush()

            # save the model
            if args.save_prefix is not None:
                path = args.save_prefix + ('_epoch{:0' + str(digits) + '}.sav').format(epoch)
                #path = args.save_prefix + '_epoch{}.sav'.format(epoch)
                model.cpu()
                model.eval()
                torch.save(model, path)
                if use_cuda:
                    model.cuda()

        models = [model]

    else: # load the saved model(s)
        models = []
        for arg in args.model:
            if arg == 'none':
                print('# Warning: no denoising model will be used', file=sys.stderr)
            else:
                print('# Loading model:', arg, file=sys.stderr)
                model = dn.load_model(arg)
                model.eval()
                if use_cuda:
                    model.cuda()
                models.append(model)

    # using trained model
    # denoise the images
    normalize = args.normalize
    if args.format_ == 'png' or args.format_ == 'jpg':
        # always normalize png and jpg format
        normalize = True

    format_ = args.format_
    suffix = args.suffix

    lowpass = args.lowpass
    gaus = args.gaussian
    if gaus > 0:
        gaus = dn.GaussianDenoise(gaus)
        if use_cuda:
            gaus.cuda()
    else:
        gaus = None
    inv_gaus = args.inv_gaussian
    if inv_gaus > 0:
        inv_gaus = dn.InvGaussianFilter(inv_gaus)
        if use_cuda:
            inv_gaus.cuda()
    else:
        inv_gaus = None
    deconvolve = args.deconvolve
    deconv_patch = args.deconv_patch

    ps = args.patch_size
    padding = args.patch_padding

    count = 0

    # we are denoising a single MRC stack
    if args.stack:
        with open(args.micrographs[0], 'rb') as f:
            content = f.read()
        stack,_,_ = mrc.parse(content)
        print('# denoising stack with shape:', stack.shape, file=sys.stderr)
        total = len(stack)

        denoised = np.zeros_like(stack)
        for i in range(len(stack)):
            mic = stack[i]
            # process and denoise the micrograph
            mic = denoise_image(mic, models, lowpass=lowpass, cutoff=cutoff, gaus=gaus,
                                inv_gaus=inv_gaus, deconvolve=deconvolve,
                                deconv_patch=deconv_patch,
                                patch_size=ps, padding=padding, normalize=normalize,
                                use_cuda=use_cuda)
            denoised[i] = mic

            count += 1
            print('# {} of {} completed.'.format(count, total), file=sys.stderr, end='\r')
        print('', file=sys.stderr)

        # write the denoised stack
        path = args.output
        print('# writing', path, file=sys.stderr)
        with open(path, 'wb') as f:
            mrc.write(f, denoised)
    else:
        # stream the micrographs and denoise them
        total = len(args.micrographs)

        # make the output directory if it doesn't exist
        # (only when an output directory was actually given)
        if args.output and not os.path.exists(args.output):
            os.makedirs(args.output)

        for path in args.micrographs:
            name,_ = os.path.splitext(os.path.basename(path))
            mic = np.array(load_image(path), copy=False).astype(np.float32)

            # process and denoise the micrograph
            mic = denoise_image(mic, models, lowpass=lowpass, cutoff=cutoff, gaus=gaus,
                                inv_gaus=inv_gaus, deconvolve=deconvolve,
                                deconv_patch=deconv_patch,
                                patch_size=ps, padding=padding, normalize=normalize,
                                use_cuda=use_cuda)

            # write the micrograph
            if not args.output:
                if suffix == '' or suffix is None:
                    suffix = '.denoised'
                # write the file to the same location as the input
                no_ext,ext = os.path.splitext(path)
                outpath = no_ext + suffix + '.' + format_
            else:
                outpath = args.output + os.sep + name + suffix + '.' + format_
            save_image(mic, outpath) #, mi=None, ma=None)

            count += 1
            print('# {} of {} completed.'.format(count, total), file=sys.stderr, end='\r')
        print('', file=sys.stderr)
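# A minimal sketch (not part of the command) of applying a saved model to a single
# micrograph programmatically, mirroring the streaming branch of main() above. The
# paths and parameter values are hypothetical; only the denoise_image signature is
# taken from the call above.
def _example_denoise_one_micrograph():
    model = dn.load_model('unet.sav') # hypothetical saved model
    model.eval()

    mic = np.array(load_image('micrograph.mrc'), copy=False).astype(np.float32)
    mic = denoise_image(mic, [model], lowpass=1, cutoff=0, gaus=None, inv_gaus=None,
                        deconvolve=False, deconv_patch=1, patch_size=1024, padding=256,
                        normalize=False, use_cuda=False)
    save_image(mic, 'micrograph.denoised.mrc')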
def denoise(model, path, outdir, patch_size=128, padding=128, batch_size=1):
    with open(path, 'rb') as f:
        content = f.read()
    tomo,header,_ = mrc.parse(content)
    name = os.path.basename(path)

    mu = tomo.mean()
    std = tomo.std()

    # denoise in patches
    d = next(iter(model.parameters())).device
    denoised = np.zeros_like(tomo)

    with torch.no_grad():
        if patch_size < 1: # denoise the whole volume at once
            x = (tomo - mu)/std
            x = torch.from_numpy(x).to(d)
            x = model(x.unsqueeze(0).unsqueeze(0)).squeeze().cpu().numpy()
            x = std*x + mu
            denoised[:] = x
        else:
            patch_data = PatchDataset(tomo, patch_size, padding)
            total = len(patch_data)
            count = 0

            batch_iterator = torch.utils.data.DataLoader(patch_data, batch_size=batch_size)
            for index,x in batch_iterator:
                x = x.to(d)
                x = (x - mu)/std
                x = x.unsqueeze(1) # batch x channel

                # denoise
                x = model(x).squeeze(1).cpu().numpy()
                # restore the original statistics (matches the whole-volume branch above)
                x = std*x + mu

                # stitch into denoised volume
                for b in range(len(x)):
                    i,j,k = index[b]
                    xb = x[b]

                    patch = denoised[i:i+patch_size,j:j+patch_size,k:k+patch_size]
                    pz,py,px = patch.shape

                    xb = xb[padding:padding+pz,padding:padding+py,padding:padding+px]
                    denoised[i:i+patch_size,j:j+patch_size,k:k+patch_size] = xb

                    count += 1
                    print('# [{}/{}] {:.2%}'.format(count, total, count/total), name,
                          file=sys.stderr, end='\r')

            print(' '*100, file=sys.stderr, end='\r')

    ## save the denoised tomogram
    outpath = outdir + os.sep + name

    # keep the parsed header (a namedtuple), updating a few fields for the new data
    header = header._replace(mode=2) # 32-bit real
    header = header._replace(amin=denoised.min())
    header = header._replace(amax=denoised.max())
    header = header._replace(amean=denoised.mean())

    with open(outpath, 'wb') as f:
        mrc.write(f, denoised, header=header)
def save_mrc(x, path):
    with open(path, 'wb') as f:
        x = x[np.newaxis] # need to add a z-axis for the MRC writer
        mrc.write(f, x)
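# Example usage (hypothetical array and path): a 2D image becomes a
# single-section volume before writing.
# save_mrc(np.zeros((256, 256), dtype=np.float32), 'blank.mrc')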
def denoise(model, path, outdir, suffix, patch_size=128, padding=128, batch_size=1,
            volume_num=1, total_volumes=1):
    with open(path, 'rb') as f:
        content = f.read()
    tomo,header,extended_header = mrc.parse(content)
    tomo = tomo.astype(np.float32)
    name = os.path.basename(path)

    mu = tomo.mean()
    std = tomo.std()

    # denoise in patches
    d = next(iter(model.parameters())).device
    denoised = np.zeros_like(tomo)

    with torch.no_grad():
        if patch_size < 1: # denoise the whole volume at once
            x = (tomo - mu)/std
            x = torch.from_numpy(x).to(d)
            x = model(x.unsqueeze(0).unsqueeze(0)).squeeze().cpu().numpy()
            x = std*x + mu
            denoised[:] = x
        else:
            patch_data = PatchDataset(tomo, patch_size, padding)
            total = len(patch_data)
            count = 0

            batch_iterator = torch.utils.data.DataLoader(patch_data, batch_size=batch_size)
            for index,x in batch_iterator:
                x = x.to(d)
                x = (x - mu)/std
                x = x.unsqueeze(1) # batch x channel

                # denoise
                x = model(x)
                x = x.squeeze(1).cpu().numpy()

                # restore original statistics
                x = std*x + mu

                # stitch into denoised volume
                for b in range(len(x)):
                    i,j,k = index[b]
                    xb = x[b]

                    patch = denoised[i:i+patch_size,j:j+patch_size,k:k+patch_size]
                    pz,py,px = patch.shape

                    xb = xb[padding:padding+pz,padding:padding+py,padding:padding+px]
                    denoised[i:i+patch_size,j:j+patch_size,k:k+patch_size] = xb

                    count += 1
                    print('# [{}/{}] {:.2%}'.format(volume_num, total_volumes, count/total), name,
                          file=sys.stderr, end='\r')

            print(' '*100, file=sys.stderr, end='\r')

    ## save the denoised tomogram
    if outdir is None:
        # write the denoised tomogram to the same location as the input, adding the suffix
        if suffix is None: # use the default suffix
            suffix = '.denoised'
        no_ext,ext = os.path.splitext(path)
        outpath = no_ext + suffix + ext
    else:
        if suffix is None:
            suffix = ''
        no_ext,ext = os.path.splitext(name)
        outpath = outdir + os.sep + no_ext + suffix + ext

    # keep the parsed header, updating a few fields for the new data
    header = header._replace(mode=2) # 32-bit real
    header = header._replace(amin=denoised.min())
    header = header._replace(amax=denoised.max())
    header = header._replace(amean=denoised.mean())

    with open(outpath, 'wb') as f:
        mrc.write(f, denoised, header=header, extended_header=extended_header)
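# A minimal sketch of the PatchDataset the two denoise() variants above rely on,
# assuming it tiles the volume in steps of patch_size and returns each cube with
# `padding` voxels of zero-padded context on every side (which the stitching loop
# then crops away). This illustrates the indexing contract, not the packaged
# implementation.
class _ExamplePatchDataset:
    def __init__(self, tomo, patch_size=128, padding=128):
        self.tomo = tomo
        self.patch_size = patch_size
        self.padding = padding
        # number of patches along each axis
        self.shape = tuple(int(np.ceil(n/patch_size)) for n in tomo.shape)

    def __len__(self):
        pz,py,px = self.shape
        return pz*py*px

    def __getitem__(self, idx):
        # flat index -> (i,j,k) patch grid coordinates
        pz,py,px = self.shape
        k = idx % px
        j = (idx // px) % py
        i = idx // (px*py)

        ps, pad = self.patch_size, self.padding
        size = ps + 2*pad
        x = np.zeros((size, size, size), dtype=np.float32)

        # copy the in-bounds part of the padded window into the zero cube
        corner = (i*ps, j*ps, k*ps)
        lo = [max(0, c - pad) for c in corner]
        hi = [min(n, c + ps + pad) for c,n in zip(corner, self.tomo.shape)]
        sub = self.tomo[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
        off = [l - (c - pad) for l,c in zip(lo, corner)]
        x[off[0]:off[0]+sub.shape[0], off[1]:off[1]+sub.shape[1],
          off[2]:off[2]+sub.shape[2]] = sub

        # the corner coordinates are what denoise() unpacks as i,j,k when stitching
        return np.array(corner), x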
def main(args):
    ## set the device
    use_cuda = False
    if args.device >= 0:
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(args.device)
    print('# using device={} with cuda={}'.format(args.device, use_cuda), file=sys.stderr)

    do_train = (args.dir_a is not None and args.dir_b is not None) or (args.hdf is not None)
    if do_train:
        if args.hdf is None: # use dirA/dirB
            crop = args.crop
            dir_a = args.dir_a
            dir_b = args.dir_b
            random = np.random.RandomState(44444)

            dataset_train, dataset_val = make_paired_images_datasets(dir_a, dir_b, crop,
                                                                     random=random)
            shuffle = True
        else: # make HDF datasets
            dataset_train, dataset_val = make_hdf5_datasets(args.hdf)
            shuffle = False

        # initialize the model
        #model = dn.DenoiseNet(32)
        model = dn.UDenoiseNet()
        if use_cuda:
            model = model.cuda()

        # train
        lr = args.lr
        batch_size = args.batch_size
        num_epochs = args.num_epochs

        num_workers = args.num_workers

        print('epoch', 'loss_train', 'loss_val')
        #criteria = nn.L1Loss()
        criteria = args.criteria

        for epoch, loss_train, loss_val in dn.train_noise2noise(model, dataset_train,
                                                                lr=lr,
                                                                batch_size=batch_size,
                                                                criteria=criteria,
                                                                num_epochs=num_epochs,
                                                                dataset_val=dataset_val,
                                                                use_cuda=use_cuda,
                                                                num_workers=num_workers,
                                                                shuffle=shuffle):
            print(epoch, loss_train, loss_val)
            sys.stdout.flush()

            # save the model
            if args.save_prefix is not None:
                path = args.save_prefix + '_epoch{}.sav'.format(epoch)
                model.cpu()
                model.eval()
                torch.save(model, path)
                if use_cuda:
                    model.cuda()

    else: # load the saved model
        if args.model in ['L0', 'L1', 'L2']:
            if args.model in ['L0', 'L1']:
                print('ERROR: L0 and L1 models are not implemented in the current version',
                      file=sys.stderr)
                sys.exit(1)
            model = dn.load_model(args.model)
        else:
            model = torch.load(args.model)
        print('# using model:', args.model, file=sys.stderr)
        model.eval()
        if use_cuda:
            model.cuda()

    if args.stack: # we are denoising a single MRC stack
        with open(args.micrographs[0], 'rb') as f:
            content = f.read()
        stack,_,_ = mrc.parse(content)
        print('# denoising stack with shape:', stack.shape, file=sys.stderr)

        denoised = dn.denoise_stack(model, stack, use_cuda=use_cuda)

        # write the denoised stack
        path = args.output
        print('# writing', path, file=sys.stderr)
        with open(path, 'wb') as f:
            mrc.write(f, denoised)

    else:
        # using trained model
        # stream the micrographs and denoise as we go
        normalize = args.normalize
        if args.format_ == 'png' or args.format_ == 'jpg':
            # always normalize png and jpg format
            normalize = True

        format_ = args.format_

        count = 0
        total = len(args.micrographs)

        bin_ = args.bin
        ps = args.patch_size
        padding = args.patch_padding

        # now, stream the micrographs and denoise them
        for path in args.micrographs:
            name,_ = os.path.splitext(os.path.basename(path))
            mic = np.array(load_image(path), copy=False)
            if bin_ > 1:
                mic = downsample(mic, bin_)
            mu = mic.mean()
            std = mic.std()

            # denoise
            mic = (mic - mu)/std
            mic = dn.denoise(model, mic, patch_size=ps, padding=padding, use_cuda=use_cuda)

            if normalize:
                mic = (mic - mic.mean())/mic.std()
            else:
                # add back the std. dev. and mean
                mic = std*mic + mu

            # write the micrograph
            outpath = args.output + os.sep + name + '.' + format_
            save_image(mic, outpath)

            count += 1
            print('# {} of {} completed.'.format(count, total), file=sys.stderr, end='\r')
        print('', file=sys.stderr)
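# A minimal sketch (hypothetical paths and values) of the legacy flow above applied
# to one image; only the dn.denoise signature is taken from the call above.
def _example_denoise_legacy():
    model = torch.load('model_epoch10.sav') # hypothetical checkpoint
    model.eval()

    mic = np.array(load_image('micrograph.tiff'), copy=False)
    mu, std = mic.mean(), mic.std()

    # normalize, denoise in patches, then restore the original intensity scale
    mic = (mic - mu)/std
    mic = dn.denoise(model, mic, patch_size=1024, padding=128, use_cuda=False)
    mic = std*mic + mu

    save_image(mic, 'micrograph_denoised.png')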