def __call__(self, sample):
    """Crop the image entries of ``sample`` named in ``self.keys``.

    Each selected entry is passed to ``center_crop`` with a target size over
    its last two axes of ``min(size, self.resolution)``. The requested crop
    therefore never exceeds the entry's current size.

    The entries are replaced inside ``sample`` itself, and that same mapping
    is returned.
    """
    for key in self.keys:
        image = sample[key]
        # Clamp the requested size so an axis smaller than the resolution is
        # never asked to grow.
        target_rows = min(image.shape[-2], self.resolution)
        target_cols = min(image.shape[-1], self.resolution)
        sample[key] = center_crop(image, (target_rows, target_cols))
    return sample
def run_sn(args, data_loader, model):
    """Run Sigmanet reconstruction over ``data_loader`` and save the results.

    For every batch:

    1. The model output is passed through ``postprocess``, which crops it to
       ``(rec_x, rec_y)`` taken from ``sample['attrs']['metadata']``.
    2. If requested, the output is masked outside the foreground.
    3. It is rescaled by ``sample['attrs']['norm']`` and moved to NumPy.

    The slices of each file are then ordered by slice index, stacked, and
    written to ``args.out_dir`` with ``save_reconstructions``.

    Args:
        args: Run configuration. This function reads ``device``,
            ``mask_bg``, ``use_bg_noise_mean``, ``debug`` and ``out_dir``.
        data_loader: Iterable of raw samples accepted by
            ``data_batch._read_data``.
        model: Network called as
            ``model(input, kspace, smaps, mask, attrs)``. It is switched to
            eval mode here.
    """
    model.eval()
    logging.info('Run Sigmanet reconstruction')
    logging.info(f'Arguments: {args}')

    reconstructions = defaultdict(list)
    with torch.no_grad():
        for ii, sample in enumerate(tqdm(iter(data_loader))):
            sample = data_batch._read_data(sample, device=args.device)
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = model(sample['input'], sample['kspace'], sample['smaps'],
                      sample['mask'], sample['attrs'])
            recons = postprocess(x, (rec_x, rec_y))

            # Mask the background. Optionally fill it with the background
            # mean value instead of zero.
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'], (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # Renormalize with the per-sample norm, then move to NumPy.
            norm = sample['attrs']['norm'].reshape(-1, 1, 1)
            recons = recons * norm
            recons = recons.to('cpu').numpy()

            if args.debug and ii % 10 == 0:
                plt.imsave(
                    'run_sn_progress.png',
                    np.hstack(recons),
                    cmap='gray',
                )

            # NOTE(review): sample['fname'] is used as one key for the whole
            # batch, so each batch is assumed to hold slices of a single
            # file. Confirm against the data loader.
            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    # Order slices by slice index only. Sorting the (slidx, array) tuples
    # directly would compare ndarrays whenever two indices are equal, which
    # raises "truth value of an array is ambiguous".
    reconstructions = {
        fname: np.stack(
            [pred for _, pred in sorted(slice_preds, key=lambda p: p[0])])
        for fname, slice_preds in reconstructions.items()
    }
    save_reconstructions(reconstructions, args.out_dir)
def postprocess(tensor, shape):
    """Convert ``tensor`` to a magnitude image and crop it to the ROI.

    The crop size is ``(min(nFE, shape[0]), min(nPE, shape[1]))``, where
    ``nFE`` and ``nPE`` are the last two axes of the image after the
    magnitude step. The requested crop is therefore never larger than the
    image.

    A tensor whose last axis has length 2 is treated as complex
    (``[bs, nsmaps, nx, ny, 2]``). It is reduced with a root-sum-of-squares
    over axis 1 and the real/imaginary axis. Any other tensor, such as a
    real-valued ``[bs, nsmaps, nx, ny]``, is only cropped and not reduced.
    """
    looks_complex = tensor.shape[-1] == 2
    if looks_complex:
        magnitude = mytorch.mri.root_sum_of_squares(
            tensor, dim=(1, -1), eps=1e-9)
    else:
        magnitude = tensor
    fe_size = min(magnitude.shape[-2], shape[0])
    pe_size = min(magnitude.shape[-1], shape[1])
    return center_crop(magnitude, (fe_size, pe_size))
def run_zero_filled_sense(args, data_loader):
    """Run the adjoint (zero-filled SENSE) reconstruction and save the results.

    No network is involved. ``sample['input']`` is passed through
    ``postprocess``, which crops it to ``(rec_x, rec_y)`` from
    ``sample['attrs']['metadata']``. It is then optionally masked outside
    the foreground and rescaled by ``sample['attrs']['norm']``.

    The slices of each file are ordered by slice index, stacked, and written
    to ``args.out_dir`` with ``save_reconstructions``.

    Args:
        args: Run configuration. This function reads ``mask_bg``,
            ``use_bg_noise_mean`` and ``out_dir``.
        data_loader: Iterable of raw samples accepted by
            ``data_batch._read_data``.
    """
    logging.info('Run zero-filled SENSE reconstruction')
    logging.info(f'Arguments: {args}')

    reconstructions = defaultdict(list)
    with torch.no_grad():
        for sample in tqdm(iter(data_loader)):
            sample = data_batch._read_data(sample)
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = sample['input']
            recons = postprocess(x, (rec_x, rec_y))

            # Mask the background. Optionally fill it with the background
            # mean value instead of zero.
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'], (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # Renormalize in NumPy with the per-sample norm.
            norm = sample['attrs']['norm'].numpy()[:, np.newaxis, np.newaxis]
            recons = recons.numpy() * norm

            # NOTE(review): sample['fname'] is used as one key for the whole
            # batch, so each batch is assumed to hold slices of a single
            # file. Confirm against the data loader.
            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    # Order slices by slice index only. Sorting the (slidx, array) tuples
    # directly would compare ndarrays whenever two indices are equal, which
    # raises "truth value of an array is ambiguous".
    reconstructions = {
        fname: np.stack(
            [pred for _, pred in sorted(slice_preds, key=lambda p: p[0])])
        for fname, slice_preds in reconstructions.items()
    }
    save_reconstructions(reconstructions, args.out_dir)
def unpad2d(self, tensor, shape):
    """Crop ``tensor`` to ``shape`` unless it already has exactly that shape.

    When ``tensor.shape == shape`` the very same object is returned, with no
    copy. Otherwise the result of ``center_crop(tensor, shape)`` is returned.
    The name suggests this presumably removes padding added earlier; confirm
    with the caller.
    """
    already_sized = tensor.shape == shape
    return tensor if already_sized else center_crop(tensor, shape)