def visualize(args, epoch, model_mag, model_pha, model_vs, data_loader, writer):
    """Log one batch of magnitude/phase/complex reconstructions to TensorBoard.

    Runs the cascade (mag-net, pha-net, k-space data consistency, and -- from
    epoch 2 onward -- the VS refinement network) on the first batch of
    ``data_loader`` and writes normalized image grids under step ``epoch``.
    NOTE(review): a second ``visualize`` defined later in this file shadows
    this one at import time.
    """

    def save_image(image, tag):
        # Min-max normalize in place so the grid uses the full display range.
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model_mag.eval()
    model_pha.eval()  # fix: original called model_mag.eval() twice and left model_pha in train mode
    model_vs.eval()
    with torch.no_grad():
        for data in data_loader:
            (mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt,
             img_us_np, img_gt_np, sens, mask, _, _) = data

            inp_mag = mag_us.unsqueeze(1).to(args.device)
            tgt_mag = mag_gt.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            tgt_pha = pha_gt.unsqueeze(1).to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)
            img_gt = img_gt.to(args.device)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)
            # Enforce k-space data consistency on the combined complex image.
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            if epoch >= 2:
                # VS refinement only kicks in once the base nets are warmed up.
                out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

            save_image(inp_mag, 'Img_mag')
            save_image(tgt_mag, 'Tgt_mag')
            save_image(inp_pha, 'Img_pha')
            save_image(tgt_pha, 'Tgt_pha')

            # Magnitude of complex tensors stored as (..., H, W, 2) real/imag pairs.
            img_gt_cmplx_abs = torch.sqrt(img_gt[:, :, :, 0] ** 2 + img_gt[:, :, :, 1] ** 2).unsqueeze(1)
            out_cmplx_abs = torch.sqrt(out_cmplx[:, :, :, 0] ** 2 + out_cmplx[:, :, :, 1] ** 2).unsqueeze(1)
            # fix: dropped redundant .cuda() calls -- both operands already live on args.device
            error_cmplx = torch.abs(out_cmplx - img_gt)
            error_cmplx_abs = torch.sqrt(error_cmplx[:, :, :, 0] ** 2 + error_cmplx[:, :, :, 1] ** 2).unsqueeze(1)

            # Pad the first sample of each map to 256x256 so all grids align.
            out_cmplx_abs = T.pad(out_cmplx_abs[0, 0, :, :], [256, 256]).unsqueeze(0).unsqueeze(1)
            error_cmplx_abs = T.pad(error_cmplx_abs[0, 0, :, :], [256, 256]).unsqueeze(0).unsqueeze(1)
            img_gt_cmplx_abs = T.pad(img_gt_cmplx_abs[0, 0, :, :], [256, 256]).unsqueeze(0).unsqueeze(1)

            save_image(error_cmplx_abs, 'Error')
            save_image(out_cmplx_abs, 'Recons')
            save_image(img_gt_cmplx_abs, 'Target')
            break
def visualize(args, epoch, model_mag, model_pha, model_vs, model_dun, data_loader, writer):
    """Push one batch through the full cascade (mag/pha nets, data consistency,
    VS refinement, DUN) and log input/error/reconstruction/target grids."""

    def save_image(image, tag):
        # Normalize to [0, 1] in place before building the display grid.
        image -= image.min()
        image /= image.max()
        writer.add_image(tag, torchvision.utils.make_grid(image, nrow=4, pad_value=1), epoch)

    for net in (model_mag, model_pha, model_vs, model_dun):
        net.eval()

    with torch.no_grad():
        print(':::::::::::::::: IN VISUALIZE :::::::::::::::::::::')
        for batch in data_loader:
            (mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt,
             img_us_np, img_gt_np, sens, mask, _, _) = batch

            dev = args.device
            inp_mag = mag_us.unsqueeze(1).to(dev)
            inp_pha = pha_us.unsqueeze(1).to(dev)
            img_us_np = img_us_np.unsqueeze(1).float().to(dev)
            img_gt_np = img_gt_np.unsqueeze(1).float().to(dev)
            ksp_us = ksp_us.to(dev)
            sens = sens.to(dev)
            mask = mask.to(dev)
            img_gt = img_gt.to(dev)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)
            out_mag_unpad = T.unpad(out_mag, ksp_us)
            out_pha_unpad = T.unpad(out_pha, ksp_us)
            refined = model_vs(T.dc(out_mag, out_pha, ksp_us, sens, mask), ksp_us, sens, mask)

            # The DUN consumes only the magnitude image of the refined complex output.
            inp_dun = T.inp_to_dun(refined).float().to(dev)
            out_dun = model_dun(inp_dun)

            save_image(inp_dun.float(), 'Input')
            save_image(torch.abs(out_dun - img_gt_np), 'Error')
            save_image(out_dun, 'Recons')
            save_image(img_gt_np, 'Target')
            break
def evaluate(args, epoch, model_mag, model_pha, model_vs, data_loader, writer):
    """Compute dev-set MSE losses for the mag, pha and complex cascade outputs.

    NOTE(review): a second ``evaluate`` defined later in this file shadows
    this one at import time.

    Returns:
        (mean mag loss, mean pha loss, mean complex loss, elapsed seconds)
    """
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()
    losses_mag, losses_pha, losses_cmplx = [], [], []
    start = time.perf_counter()
    with torch.no_grad():
        for data in tqdm(data_loader):
            (mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt,
             img_us_np, img_gt_np, sens, mask, _, _) = data
            inp_mag = mag_us.unsqueeze(1).to(args.device)
            tgt_mag = mag_gt.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            tgt_pha = pha_gt.unsqueeze(1).to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)
            img_gt = img_gt.to(args.device)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            if epoch >= 2:
                # VS refinement matches the training schedule (applied from epoch 2).
                out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

            losses_mag.append(F.mse_loss(out_mag, tgt_mag).item())
            losses_pha.append(F.mse_loss(out_pha, tgt_pha).item())
            losses_cmplx.append(F.mse_loss(out_cmplx, img_gt).item())

    # fix: log epoch means once -- the original called add_scalar every batch
    # with the same `epoch` step, so each batch overwrote the previous value.
    writer.add_scalar('Dev_Loss_Mag', np.mean(losses_mag), epoch)
    writer.add_scalar('Dev_Loss_Pha', np.mean(losses_pha), epoch)
    writer.add_scalar('Dev_Loss_cmplx', np.mean(losses_cmplx), epoch)
    return np.mean(losses_mag), np.mean(losses_pha), np.mean(losses_cmplx), time.perf_counter() - start
def run_submission(model_mag, model_pha, model_vs, model_dun, data_loader):
    """Run the full cascade over the submission loader and collect reconstructions.

    Returns:
        dict mapping filename -> np.ndarray of predicted slices stacked in
        ascending slice order.
    """
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()
    model_dun.eval()
    reconstructions = defaultdict(list)
    # fix: run inference under no_grad -- the original tracked gradients
    # through the whole cascade, wasting memory for no benefit.
    with torch.no_grad():
        for data in tqdm(data_loader):
            mag_us, pha_us, ksp_us, sens, mask, fname, slice_idx, max_mag = data
            inp_mag = mag_us.unsqueeze(1).cuda()
            inp_pha = pha_us.unsqueeze(1).cuda()
            ksp_us = ksp_us.cuda()
            sens = sens.cuda()
            mask = mask.cuda()

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)
            # DUN consumes only the magnitude image of the refined complex output.
            inp_dun = T.inp_to_dun(out_cmplx).float().cuda()
            out_dun = model_dun(inp_dun)
            out_dun = T.unpad(out_dun, ksp_us)
            out_dun = out_dun.detach().cpu().squeeze(1)
            out_dun = out_dun * max_mag.float()  # undo per-volume magnitude normalization

            # fix: iterate the actual batch dimension instead of hard-coded range(1).
            for i in range(out_dun.shape[0]):
                reconstructions[fname[i]].append((slice_idx[i].numpy(), out_dun[i].numpy()))

    # Sort each volume's slices by slice index before stacking.
    return {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }
def reconstruct(args, model_mag, model_pha, model_vs, model_dun, data_loader):
    """Run the full cascade on each batch and write reconstructions to HDF5.

    Output path is derived from the dataset filename by replacing its parent
    directory with 'Recons/acc_<acceleration>x'.
    """
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()
    model_dun.eval()
    with torch.no_grad():
        for iter, data in enumerate(tqdm(data_loader)):
            mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt, img_us_np, img_gt_np, sens, mask, fname = data
            inp_mag = mag_us.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            img_us_np = img_us_np.unsqueeze(1).float().to(args.device)
            img_gt_np = img_gt_np.unsqueeze(1).float().to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)
            img_gt = img_gt.to(args.device)
            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)
            # NOTE(review): the unpadded outputs are computed but unused here.
            out_mag_unpad = T.unpad(out_mag, ksp_us)
            out_pha_unpad = T.unpad(out_pha, ksp_us)
            # k-space data consistency, then VS refinement of the complex image.
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)
            # inp_dun = T.rss(out_cmplx,sens).float().to(args.device) ## takes rss
            inp_dun = T.inp_to_dun(out_cmplx).float().to(
                args.device)  ## takes only mag
            out_dun = model_dun(inp_dun).squeeze(0)
            out_dun = out_dun.cpu()
            # Rewrite the second-to-last path component to the recon output dir.
            fname = (pathlib.Path(str(fname)))
            parts = list(fname.parts)
            parts[-2] = 'Recons/acc_' + str(args.acceleration) + 'x'
            path = pathlib.Path(*parts)
            # NOTE(review): [2:-2] strips what looks like "('...',)" tuple-repr
            # wrapping around the filename -- assumes fname arrives as a
            # 1-element tuple/list from the loader; confirm against the dataset.
            path = str(path)[2:-2]
            with h5py.File(str(path), 'w') as f:
                f.create_dataset("Recons", data=out_dun)
def evaluate(args, epoch, model_mag, model_pha, model_vs, model_dun, data_loader, writer):
    """Compute the mean dev-set SSIM loss of the DUN output.

    Returns:
        (mean DUN loss, elapsed seconds)
    """
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()
    model_dun.eval()
    losses_dun = []
    start = time.perf_counter()
    with torch.no_grad():
        print(':::::::::::::::: IN EVALUATE :::::::::::::::::::::')
        for data in tqdm(data_loader):
            (mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt,
             img_us_np, img_gt_np, sens, mask, _, _) = data
            inp_mag = mag_us.unsqueeze(1).to(args.device)
            inp_pha = pha_us.unsqueeze(1).to(args.device)
            img_gt_np = img_gt_np.unsqueeze(1).float().to(args.device)
            ksp_us = ksp_us.to(args.device)
            sens = sens.to(args.device)
            mask = mask.to(args.device)

            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)
            # DUN consumes only the magnitude image of the refined complex output.
            inp_dun = T.inp_to_dun(out_cmplx).float().to(args.device)
            out_dun = model_dun(inp_dun)

            # fix: data-range tensor follows args.device -- the original
            # hard-coded .cuda(), which breaks when args.device is 'cpu'.
            data_range = torch.tensor(1.0).unsqueeze(0).to(args.device)
            losses_dun.append(ssim_loss(out_dun, img_gt_np, data_range).item())

    # Log the epoch mean once per epoch.
    writer.add_scalar('Dev_Loss_cmplx', np.mean(losses_dun), epoch)
    return np.mean(losses_dun), time.perf_counter() - start
def train_epoch(args, epoch, model_mag, model_pha, model_vs, data_loader, optimizer_mag, optimizer_pha, optimizer_vs, writer):
    """Train one epoch of the staged mag/pha/VS cascade.

    Training schedule (by epoch index):
      * epoch < 1  : mag-net and pha-net trained on their own MSE losses.
      * 1 <= e < 2 : mag-net and pha-net trained through the complex DC loss.
      * 2 <= e < 3 : only the VS network trained on the complex loss.
      * e >= 3     : all three networks trained jointly on the complex loss.

    Returns:
        (avg mag loss, avg pha loss, avg complex loss, elapsed seconds)
    """

    def _ema(avg, val, it):
        # Exponential moving average used for progress reporting only.
        return 0.99 * avg + 0.01 * val if it > 0 else val

    avg_loss_mag = avg_loss_pha = avg_loss_cmplx = 0.
    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(data_loader)

    for it, data in enumerate(tqdm(data_loader)):
        (mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt,
         img_us_np, img_gt_np, sens, mask, _, _) = data
        inp_mag = mag_us.unsqueeze(1).to(args.device)
        tgt_mag = mag_gt.unsqueeze(1).to(args.device)
        inp_pha = pha_us.unsqueeze(1).to(args.device)
        tgt_pha = pha_gt.unsqueeze(1).to(args.device)
        ksp_us = ksp_us.to(args.device)
        sens = sens.to(args.device)
        mask = mask.to(args.device)
        img_gt = img_gt.to(args.device)

        # fix: set train/eval modes BEFORE the forward passes -- the original
        # toggled them inside the phase branches, after out_mag/out_pha were
        # already computed, so the first forward of each phase ran in the
        # previous phase's mode.
        if epoch < 2:
            model_mag.train()
            model_pha.train()
            model_vs.eval()
        elif epoch < 3:
            model_mag.eval()
            model_pha.eval()
            model_vs.train()
        else:
            model_mag.train()
            model_pha.train()
            model_vs.train()

        out_mag = model_mag(inp_mag)
        out_pha = model_pha(inp_pha)
        out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
        if epoch >= 2:
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)

        loss_mag = F.mse_loss(out_mag, tgt_mag)
        loss_pha = F.mse_loss(out_pha, tgt_pha)
        # fix: img_gt is already on args.device; original mixed .cuda() and
        # .to(args.device) between branches.
        loss_cmplx = F.mse_loss(out_cmplx, img_gt)

        if epoch < 1:
            # Phase 1: independent supervision of the mag and pha networks.
            optimizer_mag.zero_grad()
            optimizer_pha.zero_grad()
            loss_mag.backward()
            loss_pha.backward()
            optimizer_mag.step()
            optimizer_pha.step()
        elif epoch < 2:
            # Phase 2: mag/pha trained jointly through the complex DC loss.
            optimizer_mag.zero_grad()
            optimizer_pha.zero_grad()
            loss_cmplx.backward()
            optimizer_mag.step()
            optimizer_pha.step()
        elif epoch < 3:
            # Phase 3: only the VS network is updated.
            optimizer_vs.zero_grad()
            loss_cmplx.backward()
            optimizer_vs.step()
        else:
            # Phase 4: end-to-end fine-tuning of all three networks.
            optimizer_mag.zero_grad()
            optimizer_pha.zero_grad()
            optimizer_vs.zero_grad()
            loss_cmplx.backward()
            optimizer_mag.step()
            optimizer_pha.step()
            optimizer_vs.step()

        avg_loss_mag = _ema(avg_loss_mag, loss_mag.item(), it)
        avg_loss_pha = _ema(avg_loss_pha, loss_pha.item(), it)
        avg_loss_cmplx = _ema(avg_loss_cmplx, loss_cmplx.item(), it)
        writer.add_scalar('TrainLoss_mag', loss_mag.item(), global_step + it)
        writer.add_scalar('TrainLoss_pha', loss_pha.item(), global_step + it)
        writer.add_scalar('TrainLoss_cmplx', loss_cmplx.item(), global_step + it)
        if it % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{it:4d}/{len(data_loader):4d}] '
                f'Loss_mag = {loss_mag.item():.4g} Avg Loss Mag = {avg_loss_mag:.4g} '
                f'Loss_pha = {loss_pha.item():.4g} Avg Loss Pha = {avg_loss_pha:.4g} '
                f'Loss_cmplx = {loss_cmplx.item():.4g} Avg Loss cmplx = {avg_loss_cmplx:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s',
            )
            start_iter = time.perf_counter()

    return avg_loss_mag, avg_loss_pha, avg_loss_cmplx, time.perf_counter() - start_epoch
def train_epoch(args, epoch, model_mag, model_pha, model_vs, model_dun, data_loader, optimizer_dun, writer):
    """Train the DUN for one epoch on top of the frozen mag/pha/VS cascade.

    NOTE(review): this definition shadows the earlier ``train_epoch`` in this
    file at import time.

    Returns:
        (avg DUN loss, elapsed seconds)
    """
    model_mag.eval()
    model_pha.eval()
    model_vs.eval()
    model_dun.train()  # fix: original called .eval() on the network being trained
    avg_loss_dun = 0.
    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(data_loader)
    for it, data in enumerate(tqdm(data_loader)):
        (mag_us, mag_gt, pha_us, pha_gt, ksp_us, img_us, img_gt,
         img_us_np, img_gt_np, sens, mask, _, _) = data
        inp_mag = mag_us.unsqueeze(1).to(args.device)
        inp_pha = pha_us.unsqueeze(1).to(args.device)
        img_gt_np = img_gt_np.unsqueeze(1).float().to(args.device)
        ksp_us = ksp_us.to(args.device)
        sens = sens.to(args.device)
        mask = mask.to(args.device)

        # The upstream cascade is frozen (only optimizer_dun steps), so skip
        # building its autograd graph.
        with torch.no_grad():
            out_mag = model_mag(inp_mag)
            out_pha = model_pha(inp_pha)
            out_cmplx = T.dc(out_mag, out_pha, ksp_us, sens, mask)
            out_cmplx = model_vs(out_cmplx, ksp_us, sens, mask)
            # DUN consumes only the magnitude image of the refined output.
            inp_dun = T.inp_to_dun(out_cmplx).float().to(args.device)

        out_dun = model_dun(inp_dun)
        # fix: data-range tensor follows args.device instead of hard-coded .cuda()
        data_range = torch.tensor(1.0).unsqueeze(0).to(args.device)
        loss_dun = ssim_loss(out_dun, img_gt_np, data_range)

        optimizer_dun.zero_grad()
        loss_dun.backward()
        optimizer_dun.step()

        # EMA of the loss for progress reporting only.
        avg_loss_dun = 0.99 * avg_loss_dun + 0.01 * loss_dun.item() if it > 0 else loss_dun.item()
        writer.add_scalar('TrainLoss_cmplx', loss_dun.item(), global_step + it)
        if it % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{it:4d}/{len(data_loader):4d}] '
                f'Loss_dun = {loss_dun.item():.4g} Avg Loss dun = {avg_loss_dun:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s',
            )
            start_iter = time.perf_counter()
    return avg_loss_dun, time.perf_counter() - start_epoch