def time_average(data, dim, eps=1e-6, keepdim=True):
    """Average ``data`` along ``dim``, normalizing by the mask rather than the length.

    The numerator is the plain sum of ``data`` over ``dim``. The denominator is
    the sum over ``dim`` of the mask returned by ``cplx.get_mask(data)``
    (presumably 1 where ``data`` is non-zero -- confirm against ``cplx``), plus
    ``eps`` so that an all-zero slice does not divide by zero.

    Args:
        data: Tensor to average.
        dim: Dimension(s) to reduce over.
        eps: Small constant added to the denominator.
        keepdim: Whether the reduced dimension(s) are kept with size 1.

    Returns:
        Tensor of masked averages.
    """
    total = data.sum(dim, keepdim=keepdim)
    sample_count = cplx.get_mask(data).sum(dim, keepdim=keepdim)
    return total / (sample_count + eps)
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    """Log images and step sizes for the first batch of ``data_loader`` to ``writer``.

    Runs ``model`` under ``torch.no_grad()`` on the first batch only (the loop
    breaks after one iteration). Under the tag prefix ``Train``/``Val`` it logs:
    magnitude and phase of [initial guess | output | target] placed side by side,
    the magnitude error ``|output - target|``, the sampling mask, and each
    unrolled iteration's L/R step sizes as scalars.

    Args:
        args: Namespace providing ``device`` and ``num_grad_steps``.
        epoch: Global step value passed to the writer.
        model: Called as ``model(kspace, maps, initial_guess=(L, R))`` and
            expected to return ``(output, summary_data)``.
        data_loader: Iterable yielding
            ``(kspace, maps, L, R, target, mean, std, norm)`` tuples.
        writer: Object exposing ``add_image`` / ``add_scalar`` (e.g. a
            TensorBoard ``SummaryWriter``).
        is_training: Selects the ``Train`` vs ``Val`` tag prefix.

    Note:
        The model is switched to eval mode and is left in eval mode on return.
    """
    def save_image(image, tag):
        # [batch, height, width, channels] -> [batch, channels, height, width]
        # as expected by make_grid.
        image = image.permute(0, 3, 1, 2)
        # Min-max normalize out of place. A constant image (e.g. a zero error
        # map) has max 0 after the shift; dividing would yield 0/0 = NaN.
        image = image - image.min()
        peak = image.max()
        if peak > 0:
            image = image / peak
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            # Load all data arrays (mean/std/norm are not used here)
            kspace, maps, L, R, target, _mean, _std, _norm = data
            kspace = kspace.to(args.device)
            maps = maps.to(args.device)
            # Drop the leading dim of the L/R initial guess (presumably a
            # batch of size 1 -- confirm against the data loader).
            L = L.to(args.device).squeeze(0)
            R = R.to(args.device).squeeze(0)
            target = target.to(args.device)

            # Data dimensions (for reference)
            # image size:  [batch_size, nx, ny, nt, nmaps, 2]
            # kspace size: [batch_size, nkx, nky, nt, ncoils, 2]
            # maps size:   [batch_size, nkx, ny, 1, ncoils, nmaps, 2]

            # Compute DL recon
            output, summary_data = model(kspace, maps, initial_guess=(L, R))

            # Get initial guess
            init = summary_data['init_image']

            # Show time frame 10 when it exists; fall back to the last frame
            # for shorter series instead of raising IndexError.
            frame = min(10, output.shape[3] - 1)

            # Slice images: keep frame `frame` of ESPIRiT map 0 -> [b, x, y, 1, 2]
            init = init[:, :, :, frame, 0, None]
            output = output[:, :, :, frame, 0, None]
            target = target[:, :, :, frame, 0, None]
            # Last kx line, coil 0 -> [b, y, t, 2]
            mask = cplx.get_mask(kspace[:, -1, :, :, 0, :])

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images), '%s_Images' % tag)
            save_image(cplx.angle(all_images), '%s_Phase' % tag)
            save_image(cplx.abs(output - target), '%s_Error' % tag)
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            # Save scalars to summary
            for i in range(args.num_grad_steps):
                step_size_L = summary_data['step_size_L_%d' % i]
                writer.add_scalar('step_sizes/L%d' % i, step_size_L.item(), epoch)
                step_size_R = summary_data['step_size_R_%d' % i]
                writer.add_scalar('step_sizes/R%d' % i, step_size_R.item(), epoch)

            # Only the first batch is visualized.
            break
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    """Log reconstructions for the first batch of ``data_loader`` to ``writer``.

    Runs ``model`` under ``torch.no_grad()`` on the first batch only (the loop
    breaks after one iteration). Under the tag prefix ``Train``/``Val`` it logs:
    - magnitude and phase of [zero-filled | output | target] side by side,
      resized to 320 x 960;
    - the magnitude error ``|output - target|``, resized to 320 x 320;
    - the sampling mask, at native size.

    Args:
        args: Namespace providing ``device``.
        epoch: Global step value passed to the writer.
        model: Called as ``model(kspace, maps)`` and expected to return the
            reconstructed image.
        data_loader: Iterable yielding
            ``(kspace, maps, target, mean, std, norm)`` tuples.
        writer: Object exposing ``add_image`` (e.g. a TensorBoard
            ``SummaryWriter``).
        is_training: Selects the ``Train`` vs ``Val`` tag prefix.

    Note:
        The model is switched to eval mode and is left in eval mode on return.
    """
    def save_image(image, tag, shape=None):
        # [batch, height, width, channels] -> [batch, channels, height, width]
        image = image.permute(0, 3, 1, 2)
        # Min-max normalize out of place. A constant image (e.g. a zero error
        # map) has max 0 after the shift; dividing would yield 0/0 = NaN.
        image = image - image.min()
        peak = image.max()
        if peak > 0:
            image = image / peak
        if shape is not None:
            image = torch.nn.functional.interpolate(
                image, size=shape, mode='bilinear', align_corners=True)
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            # Load all data arrays (mean/std/norm are not used here)
            kspace, maps, target, _mean, _std, _norm = data
            kspace = kspace.to(args.device)
            maps = maps.to(args.device)
            target = target.to(args.device)

            # Compute zero-filled recon
            A = T.SenseModel(maps)
            zf = A(kspace, adjoint=True)

            # Compute DL recon
            output = model(kspace, maps)

            # Slice images [b, y, z, e, 2] -> keep map 0
            init = zf[:, :, :, 0, None]
            output = output[:, :, :, 0, None]
            target = target[:, :, :, 0, None]
            mask = cplx.get_mask(kspace[:, :, :, 0])  # [b, y, t, 2]

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images), '%s_Images' % tag, shape=[320, 3 * 320])
            save_image(cplx.angle(all_images), '%s_Phase' % tag, shape=[320, 3 * 320])
            save_image(cplx.abs(output - target), '%s_Error' % tag, shape=[320, 320])
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            # Only the first batch is visualized.
            break
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    """Log reconstructions for the first batch of ``data_loader`` to ``writer``.

    Runs ``model`` under ``torch.no_grad()`` on the first batch only (the loop
    breaks after one iteration). Under the tag prefix ``Train``/``Val`` it logs:
    magnitude and phase of [initial image | output | target] placed side by
    side, the magnitude error ``|output - target|``, and the sampling mask.

    Args:
        args: Namespace providing ``device``.
        epoch: Global step value passed to the writer.
        model: Called as ``model(kspace, maps, init_image=init)`` and expected
            to return the reconstructed image.
        data_loader: Iterable yielding
            ``(kspace, maps, init, target, mean, std, norm)`` tuples.
        writer: Object exposing ``add_image`` (e.g. a TensorBoard
            ``SummaryWriter``).
        is_training: Selects the ``Train`` vs ``Val`` tag prefix.

    Note:
        The model is switched to eval mode and is left in eval mode on return.
    """
    def save_image(image, tag):
        # [batch, height, width, channels] -> [batch, channels, height, width]
        image = image.permute(0, 3, 1, 2)
        # Min-max normalize out of place. A constant image (e.g. a zero error
        # map) has max 0 after the shift; dividing would yield 0/0 = NaN.
        image = image - image.min()
        peak = image.max()
        if peak > 0:
            image = image / peak
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            # Load all data arrays (mean/std/norm are not used here)
            kspace, maps, init, target, _mean, _std, _norm = data
            kspace = kspace.to(args.device)
            maps = maps.to(args.device)
            init = init.to(args.device)
            target = target.to(args.device)

            # Data dimensions (for reference)
            # image size:  [batch_size, nx, ny, nt, nmaps, 2]
            # kspace size: [batch_size, nkx, nky, nt, ncoils, 2]
            # maps size:   [batch_size, nkx, ny, 1, ncoils, nmaps, 2]

            # Compute DL recon
            output = model(kspace, maps, init_image=init)

            # Show time frame 10 when it exists; fall back to the last frame
            # for shorter series instead of raising IndexError.
            frame = min(10, output.shape[3] - 1)

            # Slice images: keep frame `frame` of ESPIRiT map 0 -> [b, x, y, 1, 2]
            init = init[:, :, :, frame, 0, None]
            output = output[:, :, :, frame, 0, None]
            target = target[:, :, :, frame, 0, None]
            # Last kx line, coil 0 -> [b, y, t, 2]
            mask = cplx.get_mask(kspace[:, -1, :, :, 0, :])

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images), '%s_Images' % tag)
            save_image(cplx.angle(all_images), '%s_Phase' % tag)
            save_image(cplx.abs(output - target), '%s_Error' % tag)
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            # Only the first batch is visualized.
            break
def forward(self, kspace, maps, init_image=None, mask=None):
    """Reconstruct an image with unrolled proximal gradient descent.

    Each unrolled iteration takes a data-consistency step
    ``image + step_size * (A^H A image - A^H y)`` and then applies a learned
    residual network. The sign convention of ``step_size`` comes from how
    ``self.step_sizes`` is initialized, which is not visible here.

    Args:
        kspace (torch.Tensor): Input tensor of shape
            [batch_size, height, width, time, num_coils, 2]. Not modified.
        maps (torch.Tensor): Input tensor of shape
            [batch_size, height, width, 1, num_coils, num_emaps, 2]
        init_image (torch.Tensor, optional): Starting image for the
            iterations. Defaults to the zero-filled reconstruction.
        mask (torch.Tensor, optional): Input tensor of shape
            [batch_size, height, width, time, 1, 1]. Defaults to
            ``cplx.get_mask(kspace)``.

    Returns:
        (torch.Tensor): Output tensor of shape
            [batch_size, height, width, time, num_emaps, 2]

    Raises:
        ValueError: If ``maps`` does not carry ``self.num_emaps`` ESPIRiT maps.
    """
    if self.num_emaps != maps.size()[-2]:
        raise ValueError('Incorrect number of ESPIRiT maps! Re-prep data...')

    if mask is None:
        mask = cplx.get_mask(kspace)
    # Out-of-place multiply: the original `kspace *= mask` silently modified the
    # caller's tensor (and fails under autograd for leaf tensors needing grad).
    kspace = kspace * mask

    # Get data dimensions
    dims = tuple(kspace.size())

    # Declare signal model
    A = SenseModel(maps, weights=mask)

    # Compute zero-filled image reconstruction
    zf_image = A(kspace, adjoint=True)
    image = zf_image if init_image is None else init_image

    # Begin unrolled proximal gradient descent
    for resnet, step_size in zip(self.resnets, self.step_sizes):
        # dc update: A^H A x - A^H y
        grad_x = A(A(image), adjoint=True) - zf_image
        image = image + step_size * grad_x

        # prox update: fold (num_emaps, 2) into one channel axis, then permute
        # to [batch, num_emaps * 2, time, width, height] for the network.
        image = image.reshape(dims[0:4] + (self.num_emaps * 2,)).permute(0, 4, 3, 2, 1)
        image = resnet(image)
        image = image.permute(0, 4, 3, 2, 1).reshape(dims[0:4] + (self.num_emaps, 2))

    return image
def forward(self, kspace, maps, initial_guess=None, mask=None):
    """Reconstruct a dynamic image with unrolled low-rank (L, R) alternating updates.

    The image is represented blockwise as the product of spatial (L) and
    temporal (R) basis vectors. Each unrolled iteration:
    - takes a gradient step on L and R for the data-consistency term;
    - refines L with a spatial network and R with a temporal network;
    - recomposes the image.

    Args:
        kspace (torch.Tensor): Input tensor of shape
            [batch_size, height, width, time, num_coils, 2]
        maps (torch.Tensor): Input tensor of shape
            [batch_size, height, width, 1, num_coils, num_emaps, 2]
        initial_guess (tuple, optional): ``(L, R)`` starting basis vectors.
            Defaults to ``decompose_LR`` of the zero-filled reconstruction.
        mask (torch.Tensor, optional): Input tensor of shape
            [batch_size, height, width, time, 1, 1]. Defaults to
            ``cplx.get_mask(kspace)``.

    Intermediate variables:
        Spatial basis vectors: [batch_size, block_size, block_size, 1, num_emaps, num_basis, 2]
        Temporal basis vectors: [batch_size, 1, 1, time, 1, num_basis, 2]

    Returns:
        (torch.Tensor, dict): Output tensor of shape
            [batch_size, height, width, time, num_emaps, 2], and a summary dict
            holding 'init_image' plus per-iteration 'image_<i>',
            'step_size_L_<i>' and 'step_size_R_<i>' entries.

    Raises:
        ValueError: If ``maps`` does not carry ``self.num_emaps`` ESPIRiT maps.
    """
    if maps.size()[-2] != self.num_emaps:
        raise ValueError('Incorrect number of ESPIRiT maps! Re-prep data...')

    summary_data = {}
    image_shape = kspace.shape[0:4] + (self.num_emaps, 2)
    weights = cplx.get_mask(kspace) if mask is None else mask

    # Linear operators: SENSE forward/adjoint model and image <-> blocks mapping.
    sense = SenseModel(maps, weights=weights)
    blocks = ArrayToBlocks(self.block_size, image_shape, overlapping=self.overlapping)

    # Zero-filled reconstruction A^H y
    zf_image = sense(kspace, adjoint=True)

    # Starting basis vectors
    if initial_guess is not None:
        L, R = initial_guess
    else:
        L, R = decompose_LR(zf_image, block_op=blocks)
    image = self.compose_LR(L, R, blocks)
    summary_data['init_image'] = image

    unrolled = zip(self.sp_resnets, self.t_resnets)
    for it, (sp_resnet, t_resnet) in enumerate(unrolled):
        L_prev, R_prev = L, R

        # Gradients of ||Y - ALR'||_2 with respect to L and R, computed from
        # the blockwise image-domain gradient A^H A x - A^H y.
        image_grad = sense(sense(image), adjoint=True) - zf_image
        grad_blocks = blocks(image_grad).unsqueeze(-2)
        grad_L = torch.sum(cplx.mul(grad_blocks, R_prev), keepdim=True, dim=3)
        grad_R = torch.sum(cplx.mul(cplx.conj(grad_blocks), L_prev), keepdim=True, dim=(1, 2, 4))

        # Model (gradient) updates
        step_L, step_R = self.get_step_sizes(L_prev, R_prev)
        L = L_prev + step_L * grad_L
        R = R_prev + step_R * grad_R

        # Network updates: reshape into network layout, refine, reshape back.
        L, R = self.reshape_LR(L, L_prev.shape, R, R_prev.shape, beforeNet=True)
        L = sp_resnet(L)
        R = t_resnet(R)
        L, R = self.reshape_LR(L, L_prev.shape, R, R_prev.shape, beforeNet=False)

        # Current image estimate
        image = self.compose_LR(L, R, blocks)

        summary_data['image_%d' % it] = image
        summary_data['step_size_L_%d' % it] = step_L
        summary_data['step_size_R_%d' % it] = step_R

    return image, summary_data