def denoise_image(mic, models, lowpass=1, cutoff=0, gaus=None, inv_gaus=None,
                  deconvolve=False, deconv_patch=1, patch_size=-1, padding=0,
                  normalize=False, use_cuda=False):
    """Denoise one image with an ensemble of models and average the results.

    Processing steps:
    1. Optionally apply ``dn.lowpass``.
    2. Standardize the image to zero mean / unit standard deviation.
    3. Optionally zero outliers.
    4. Optionally pre-filter (gaussian, inverse gaussian, or spatial-covariance
       correction).
    5. Run every model through ``dn.denoise`` and average the outputs.
    6. Restore the original mean/std, or re-standardize if ``normalize`` is set.

    Args:
        mic: image as a NumPy array (converted with ``torch.from_numpy``).
        models: non-empty sequence of models; each is passed to ``dn.denoise``
            and the outputs are averaged.
        lowpass: if > 1, the image is first passed through
            ``dn.lowpass(mic, lowpass)``.
        cutoff: if > 0, standardized values below ``-cutoff`` or above
            ``cutoff`` are set to 0.
        gaus: optional filter applied with ``dn.denoise`` before the models.
        inv_gaus: optional filter applied only when ``gaus`` is None.
        deconvolve: when neither filter is given, apply
            ``dn.correct_spatial_covariance`` instead.
        deconv_patch: ``patch`` argument for ``dn.correct_spatial_covariance``.
        patch_size: forwarded to ``dn.denoise`` for every model.
        padding: forwarded to ``dn.denoise`` for every model.
        normalize: if True, return the result re-standardized to zero mean and
            unit std; otherwise restore the input's mean and std.
        use_cuda: move the image to the current CUDA device before processing.

    Returns:
        The denoised image as a NumPy array on the CPU.

    Raises:
        ValueError: if ``models`` is empty.
    """
    # Fail early with a clear message; averaging over zero models would
    # otherwise surface as an opaque ZeroDivisionError.
    if len(models) == 0:
        raise ValueError('denoise_image requires at least one model')

    if lowpass > 1:
        mic = dn.lowpass(mic, lowpass)

    mic = torch.from_numpy(mic)
    # torch's mean()/std() reject integer tensors, so promote integer-typed
    # images to float32; floating-point inputs keep their original dtype.
    if not mic.is_floating_point():
        mic = mic.float()
    if use_cuda:
        mic = mic.cuda()

    # standardize and remove outliers
    mu = mic.mean()
    std = mic.std()
    x = (mic - mu) / std
    if cutoff > 0:
        x[(x < -cutoff) | (x > cutoff)] = 0

    # apply gaussian / inverse gaussian filter, or deconvolution
    if gaus is not None:
        x = dn.denoise(gaus, x)
    elif inv_gaus is not None:
        x = dn.denoise(inv_gaus, x)
    elif deconvolve:
        # estimate optimal filter and correct spatial correlation
        x = dn.correct_spatial_covariance(x, patch=deconv_patch)

    # denoise with every model and average the predictions
    denoised = sum(
        dn.denoise(model, x, patch_size=patch_size, padding=padding)
        for model in models
    )
    denoised = denoised / len(models)

    # restore pixel scaling
    if normalize:
        denoised = (denoised - denoised.mean()) / denoised.std()
    else:
        # add back std. dev. and mean
        denoised = std * denoised + mu

    # back to numpy/cpu
    return denoised.cpu().numpy()
def _select_device(device):
    """Return True if CUDA should be used, selecting ``device`` when it is >= 0.

    CUDA is used only if ``device`` is non-negative and
    ``torch.cuda.is_available()`` reports True. The choice is printed to stderr.
    """
    use_cuda = False
    if device >= 0:
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(device)
    print('# using device={} with cuda={}'.format(device, use_cuda), file=sys.stderr)
    return use_cuda


def _train_denoiser(args, use_cuda):
    """Train a ``dn.UDenoiseNet`` with noise2noise, checkpointing every epoch.

    Data comes from the paired directories ``args.dir_a``/``args.dir_b`` when
    ``args.hdf`` is None, otherwise from the HDF5 file ``args.hdf``. Per-epoch
    losses are printed to stdout. If ``args.save_prefix`` is set, the model is
    saved to ``<save_prefix>_epoch<N>.sav`` after each epoch.
    """
    if args.hdf is None:
        # use dir_a/dir_b; fixed seed so the train/val split is reproducible
        rng = np.random.RandomState(44444)
        dataset_train, dataset_val = make_paired_images_datasets(
            args.dir_a, args.dir_b, args.crop, random=rng)
        shuffle = True
    else:
        # make HDF datasets
        dataset_train, dataset_val = make_hdf5_datasets(args.hdf)
        shuffle = False

    model = dn.UDenoiseNet()
    if use_cuda:
        model = model.cuda()

    print('epoch', 'loss_train', 'loss_val')
    training = dn.train_noise2noise(
        model, dataset_train, lr=args.lr, batch_size=args.batch_size,
        criteria=args.criteria, num_epochs=args.num_epochs,
        dataset_val=dataset_val, use_cuda=use_cuda,
        num_workers=args.num_workers, shuffle=shuffle)
    for epoch, loss_train, loss_val in training:
        print(epoch, loss_train, loss_val)
        sys.stdout.flush()

        if args.save_prefix is None:
            continue
        # save a CPU copy in eval mode, then put the model back on its device
        # and in whatever train/eval mode the training loop left it in, so
        # later epochs do not silently keep training in eval mode
        path = args.save_prefix + '_epoch{}.sav'.format(epoch)
        was_training = model.training
        model.cpu()
        model.eval()
        torch.save(model, path)
        model.train(was_training)
        if use_cuda:
            model.cuda()


def _load_denoiser(name, use_cuda):
    """Load a denoising model for inference.

    ``name`` is either a built-in model name ('L2'; 'L0' and 'L1' exit with an
    error) or a path to a file written by ``torch.save(model, path)``. The
    returned model is in eval mode and on the GPU when ``use_cuda`` is set.
    """
    import inspect

    if name in ['L0', 'L1', 'L2']:
        if name in ['L0', 'L1']:
            print('ERROR: L0 and L1 models are not implemented in the current version',
                  file=sys.stderr)
            sys.exit(1)
        model = dn.load_model(name)
    else:
        # NOTE(security): the file holds a whole pickled module, so loading it
        # executes arbitrary pickle code; only load trusted model files.
        # PyTorch >= 2.6 defaults to weights_only=True, which cannot load such
        # files, so opt out explicitly when the installed torch supports it.
        load_kwargs = {}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        model = torch.load(name, **load_kwargs)
    print('# using model:', name, file=sys.stderr)
    model.eval()
    if use_cuda:
        model.cuda()
    return model


def _denoise_stack_file(stack_path, output_path, model, use_cuda):
    """Denoise the single MRC stack at ``stack_path`` and write it to ``output_path``."""
    with open(stack_path, 'rb') as f:
        content = f.read()
    stack, _, _ = mrc.parse(content)
    print('# denoising stack with shape:', stack.shape, file=sys.stderr)
    denoised = dn.denoise_stack(model, stack, use_cuda=use_cuda)

    print('# writing', output_path, file=sys.stderr)
    with open(output_path, 'wb') as f:
        mrc.write(f, denoised)


def _denoise_micrograph_files(args, model, use_cuda):
    """Stream ``args.micrographs``, denoise each one, and save it to ``args.output``.

    Each image is optionally binned by ``args.bin`` and standardized before
    denoising. Afterwards it is re-standardized (``args.normalize``, and always
    for png/jpg output) or its original mean/std is restored. The result is
    saved as ``<args.output>/<basename>.<args.format_>``.
    """
    format_ = args.format_
    normalize = args.normalize
    if format_ in ('png', 'jpg'):
        # always normalize png and jpg format
        normalize = True

    total = len(args.micrographs)
    for count, path in enumerate(args.micrographs, start=1):
        name, _ = os.path.splitext(os.path.basename(path))
        # np.asarray avoids a copy when possible; np.array(..., copy=False)
        # raises under NumPy >= 2 whenever a copy is required.
        mic = np.asarray(load_image(path))
        if args.bin > 1:
            mic = downsample(mic, args.bin)
        mu = mic.mean()
        std = mic.std()

        # denoise
        mic = (mic - mu) / std
        mic = dn.denoise(model, mic, patch_size=args.patch_size,
                         padding=args.patch_padding, use_cuda=use_cuda)

        if normalize:
            mic = (mic - mic.mean()) / mic.std()
        else:
            # add back std. dev. and mean
            mic = std * mic + mu

        # write the micrograph; os.path.join avoids a doubled separator and
        # does not turn an empty output directory into the filesystem root
        outpath = os.path.join(args.output, name + '.' + format_)
        save_image(mic, outpath)

        print('# {} of {} completed.'.format(count, total), file=sys.stderr, end='\r')
    print('', file=sys.stderr)


def main(args):
    """Entry point: train a noise2noise denoiser, or denoise micrographs/stacks.

    ``args`` is presumably an argparse namespace (verify against the CLI
    parser). Training runs when both ``args.dir_a`` and ``args.dir_b`` are set,
    or when ``args.hdf`` is set. Otherwise the model named by ``args.model`` is
    loaded and used either on the single MRC stack ``args.micrographs[0]``
    (``args.stack``) or on every file in ``args.micrographs``.
    """
    use_cuda = _select_device(args.device)

    do_train = (args.dir_a is not None and args.dir_b is not None) or (args.hdf is not None)
    if do_train:
        _train_denoiser(args, use_cuda)
        return

    model = _load_denoiser(args.model, use_cuda)
    if args.stack:
        # we are denoising a single MRC stack
        _denoise_stack_file(args.micrographs[0], args.output, model, use_cuda)
    else:
        # stream the micrographs and denoise as we go
        _denoise_micrograph_files(args, model, use_cuda)