def run_sn(args, data_loader, model):
    """Run Sigmanet reconstruction over ``data_loader`` and save the results.

    For every batch the model output is passed through ``postprocess`` with
    the ``(rec_x, rec_y)`` size stored in the sample metadata, optionally
    background-masked, rescaled by the per-sample ``norm`` attribute and
    collected per file. Afterwards the slices of each file are stacked in
    slice-index order and written to ``args.out_dir`` via
    ``save_reconstructions``.

    Args:
        args: Namespace providing ``device``, ``mask_bg``,
            ``use_bg_noise_mean``, ``debug`` and ``out_dir``.
        data_loader: Iterable of raw batches understood by
            ``data_batch._read_data``.
        model: Network called as ``model(input, kspace, smaps, mask, attrs)``.
    """
    model.eval()
    logging.info('Run Sigmanet reconstruction')
    logging.info(f'Arguments: {args}')

    reconstructions = defaultdict(list)
    with torch.no_grad():
        for batch_idx, sample in enumerate(tqdm(iter(data_loader))):
            sample = data_batch._read_data(sample, device=args.device)
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )
            recons = postprocess(x, (rec_x, rec_y))

            # mask background using background mean value
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'],
                    (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # renormalize with the per-sample scale factor
            norm = sample['attrs']['norm'].reshape(-1, 1, 1)
            recons = recons * norm

            recons = recons.to('cpu').numpy()

            if args.debug and batch_idx % 10 == 0:
                plt.imsave(
                    'run_sn_progress.png',
                    np.hstack(recons),
                    cmap='gray',
                )

            # NOTE(review): the whole batch is keyed on sample['fname'];
            # assumes it is a single hashable file name per batch — confirm.
            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    # Sort by slice index only: sorting the raw (slidx, array) tuples would
    # fall back to comparing the arrays when two entries share an index,
    # which raises "truth value of an array ... is ambiguous".
    reconstructions = {
        fname: np.stack(
            [pred for _, pred in sorted(slice_preds, key=lambda p: p[0])])
        for fname, slice_preds in reconstructions.items()
    }

    save_reconstructions(reconstructions, args.out_dir)
def run_zero_filled_sense(args, data_loader):
    """Run the adjoint (zero-filled SENSE) baseline and save the results.

    No network is involved: each sample's ``input`` is passed through
    ``postprocess`` with the ``(rec_x, rec_y)`` size from its metadata,
    optionally background-masked, rescaled by the per-sample ``norm``
    attribute and collected per file. The slices of each file are then
    stacked in slice-index order and written to ``args.out_dir`` via
    ``save_reconstructions``.

    Args:
        args: Namespace providing ``mask_bg``, ``use_bg_noise_mean`` and
            ``out_dir``.
        data_loader: Iterable of raw batches understood by
            ``data_batch._read_data``.
    """
    logging.info('Run zero-filled SENSE reconstruction')
    logging.info(f'Arguments: {args}')

    reconstructions = defaultdict(list)
    with torch.no_grad():
        for sample in tqdm(iter(data_loader)):
            sample = data_batch._read_data(sample)
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = sample['input']
            recons = postprocess(x, (rec_x, rec_y))

            # mask background using background mean value
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'],
                    (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # renormalize; move to host memory first (as run_sn does) so that
            # .numpy() also works if the tensors live on a GPU. .cpu() is a
            # no-op for tensors already on the CPU.
            norm = sample['attrs']['norm'].cpu().numpy()[
                :, np.newaxis, np.newaxis]
            recons = recons.cpu().numpy() * norm

            # NOTE(review): the whole batch is keyed on sample['fname'];
            # assumes it is a single hashable file name per batch — confirm.
            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    # Sort by slice index only: sorting the raw (slidx, array) tuples would
    # fall back to comparing the arrays when two entries share an index,
    # which raises "truth value of an array ... is ambiguous".
    reconstructions = {
        fname: np.stack(
            [pred for _, pred in sorted(slice_preds, key=lambda p: p[0])])
        for fname, slice_preds in reconstructions.items()
    }

    save_reconstructions(reconstructions, args.out_dir)
def evaluate(state, model, data_loader, metrics, writer):
    """Compute the dev loss and metrics over ``data_loader`` and log them.

    Args:
        state: Training state; ``args.device``, ``loss_fn``, ``grad_acc``
            and ``epoch`` are read.
        model: Network called as ``model(input, kspace, smaps, mask, attrs)``.
        data_loader: Iterable of raw batches for ``data_mc._read_data``.
        metrics: Mapping ``name -> fn(target, output, sample)``; the returned
            value must support ``.item()``.
        writer: Summary writer; one ``Dev_<name>`` scalar (the mean over all
            batches) is added per collected loss/metric.

    Returns:
        ``(losses, elapsed)`` where ``losses`` maps ``'dev_loss'`` and each
        metric name to its list of per-batch values, and ``elapsed`` is the
        wall time in seconds.
    """
    model.eval()
    # Take the run configuration from the state, as train_epoch does, instead
    # of relying on a module-level ``args`` global.
    args = state.args
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    losses = defaultdict(list)

    start = time.perf_counter()
    with torch.no_grad():
        for sample in data_loader:
            sample = data_mc._read_data(sample, keys, attr_keys, args.device)
            output = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']
            output = postprocess(output, (rec_x, rec_y))
            target = postprocess(sample['target'], (rec_x, rec_y))
            sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))

            # NOTE(review): scaled by 1 / state.grad_acc like the training
            # loss; only the first element of loss_fn's result is used.
            loss = state.loss_fn(
                output=output,
                target=target,
                sample=sample,
                scale=1. / state.grad_acc,
            )[0]
            losses['dev_loss'].append(loss.item())

            # evaluate in the foreground
            target = target.unsqueeze(1) * sample['fg_mask']
            output = output.unsqueeze(1) * sample['fg_mask']

            for name, metric_fn in metrics.items():
                losses[name].append(metric_fn(target, output, sample).item())

    for name, values in losses.items():
        writer.add_scalar(f'Dev_{name}', np.mean(values), state.epoch)

    return losses, time.perf_counter() - start
def visualize(state, model, data_loader, writer):
    """Log example images for the first batch of ``data_loader``.

    Three images are written through ``save_image_writer`` for the current
    epoch: input/output/target side by side (``'und_pred_gt'``), the
    absolute errors ``|target - input|``, ``|target - output|`` and
    ``|input - output|`` side by side (``'Err_base_pred'``), and the
    foreground mask (``'Mask'``).

    Args:
        state: Training state; ``args.device`` and ``epoch`` are read.
        model: Network called as ``model(input, kspace, smaps, mask, attrs)``.
        data_loader: Iterable of raw batches for ``data_mc._read_data``.
        writer: Summary writer passed to ``save_image_writer``.
    """
    # Take the run configuration from the state, as train_epoch does, instead
    # of relying on a module-level ``args`` global.
    args = state.args
    save_image = functools.partial(save_image_writer, writer, state.epoch)
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    model.eval()
    with torch.no_grad():
        for sample in data_loader:
            sample = data_mc._read_data(sample, keys, attr_keys, args.device)
            output = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']
            output = postprocess(output, (rec_x, rec_y))
            input_img = postprocess(sample['input'], (rec_x, rec_y))
            target = postprocess(sample['target'], (rec_x, rec_y))
            fg_mask = postprocess(sample['fg_mask'], (rec_x, rec_y))

            base_err = torch.abs(target - input_img)
            pred_err = torch.abs(target - output)
            residual = torch.abs(input_img - output)
            save_image(
                torch.cat([input_img, output, target], -1).unsqueeze(0),
                'und_pred_gt',
            )
            # NOTE(review): the third argument presumably sets the display
            # range (max value) — confirm against save_image_writer.
            save_image(
                torch.cat([base_err, pred_err, residual], -1).unsqueeze(0),
                'Err_base_pred',
                base_err.max(),
            )
            save_image(fg_mask, 'Mask', 1.)
            # Only the first batch is visualized.
            break
def train_epoch(state, model, data_loader, optimizer, writer):
    """Train ``model`` for one epoch on ``data_loader``.

    Per batch: forward pass; output, target and foreground mask passed
    through ``postprocess`` with the ``(rec_x, rec_y)`` size from the sample
    metadata; ``state.loss_fn`` scaled by ``1 / args.grad_acc``; backward;
    and an optimizer step whenever ``state.global_step`` is a multiple of
    ``state.grad_acc``. Scalars are logged every ``args.report_interval``
    batches. Every 1000 global steps, example images are logged and a
    temporary checkpoint ``model_tmp.pt`` is saved.

    Args:
        state: Training state; ``args``, ``loss_fn``, ``grad_acc``,
            ``global_step`` (incremented here) and ``epoch`` are used.
        model: Network called as ``model(input, kspace, smaps, mask, attrs)``.
        data_loader: Iterable of raw batches for ``data_mc._read_data``.
        optimizer: Optimizer updating ``model``.
        writer: Summary writer used for scalars and images.

    Returns:
        ``(avg_loss, elapsed)``: an exponential moving average (decay 0.99,
        seeded with the first batch's loss) of the training loss, and the
        epoch wall time in seconds.
    """
    model.train()
    args = state.args
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    save_image = functools.partial(save_image_writer, writer, state.epoch)

    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()
    perf_avg = 0
    for step, sample in enumerate(data_loader):
        t0 = time.perf_counter()
        sample = data_mc._read_data(sample, keys, attr_keys, args.device)
        output = model(
            sample['input'],
            sample['kspace'],
            sample['smaps'],
            sample['mask'],
            sample['attrs'],
        )
        rec_x = sample['attrs']['metadata']['rec_x']
        rec_y = sample['attrs']['metadata']['rec_y']
        output = postprocess(output, (rec_x, rec_y))
        target = postprocess(sample['target'], (rec_x, rec_y))
        sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))

        loss, loss_l1, loss_ssim = state.loss_fn(
            output=output,
            target=target,
            sample=sample,
            scale=1. / args.grad_acc,
        )

        t1 = time.perf_counter()
        loss.backward()

        # Gradient accumulation. NOTE(review): the check runs before the
        # increment, so if global_step starts at 0 the very first update uses
        # a single batch. The loss is scaled by args.grad_acc, but the step
        # period is state.grad_acc; presumably these are equal — confirm.
        if state.global_step % state.grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad()
        state.global_step += 1

        t2 = time.perf_counter()
        # Time since the end of the previous iteration (includes data loading).
        perf_avg += t2 - start_iter

        loss_value = loss.item()
        avg_loss = 0.99 * avg_loss + (0.01 if step > 0 else 1) * loss_value

        if step % args.report_interval == 0:
            writer.add_scalar('TrainLoss', loss_value, state.global_step)
            writer.add_scalar('TrainL1Loss', loss_l1.item(), state.global_step)
            writer.add_scalar(
                'TrainSSIMLoss',
                loss_ssim.item(),
                state.global_step,
            )
            logging.info(f'Epoch = [{state.epoch:3d}/{args.num_epochs:3d}] '
                         f'Iter = [{step:4d}/{len(data_loader):4d}] '
                         f'Loss = {loss_value:.4g} Avg Loss = {avg_loss:.4g} '
                         f't = (tot:{perf_avg/(step+1):.1g}s'
                         f'/fwd:{t1-t0:.1g}/bwd:{t2-t1:.1g}s)')

        if state.global_step % 1000 == 0:
            # rec_x / rec_y from this iteration are still valid here.
            input_abs = postprocess(sample['input'], (rec_x, rec_y))
            base_err = torch.abs(target - input_abs)
            pred_err = torch.abs(target - output)
            residual = torch.abs(input_abs - output)
            save_image(
                torch.cat([input_abs, output, target], -1).unsqueeze(0),
                'Train_und_pred_gt',
            )
            save_image(
                torch.cat([base_err, pred_err, residual], -1).unsqueeze(0),
                'Train_Err_base_pred',
                base_err.max(),
            )
            save_model(args, args.exp_dir, state.epoch, model, optimizer,
                       avg_loss, is_new_best=False, modelname='model_tmp.pt')

        start_iter = time.perf_counter()

    return avg_loss, time.perf_counter() - start_epoch
def train_epoch(state, model, D_model, data_loader, optimizer, D_opt, writer):
    """Train generator ``model`` and discriminator ``D_model`` for one epoch.

    Uses a least-squares GAN (``MSELoss`` against constant 1/0 labels).
    While ``update_D`` is set, each iteration first performs a discriminator
    forward/backward. Every iteration then performs a generator
    forward/backward, whose loss is the adversarial term plus ``0.1`` times
    the first value returned by ``state.loss_fn``. Both optimizers step with
    gradient accumulation over ``state.grad_acc`` micro-batches.

    NOTE(review): this redefines ``train_epoch``; if it lives in the same
    module as the non-adversarial ``train_epoch`` it shadows that one.

    Args:
        state: Training state; ``args``, ``loss_fn``, ``grad_acc``,
            ``global_step`` and ``D_grad_acc_step`` (both incremented here)
            and ``epoch`` are used.
        model: Generator, called as ``model(input, kspace, smaps, mask, attrs)``.
        D_model: Discriminator applied to foreground-masked images.
        data_loader: Iterable of raw batches for ``data_mc._read_data``.
        optimizer: Optimizer for ``model``.
        D_opt: Optimizer for ``D_model``.
        writer: Summary writer used for scalars and images.

    Returns:
        ``(avg_loss, elapsed)``: ``avg_loss`` is the mean generator loss of
        the epoch as of the last report (``0.`` if no report happened), and
        ``elapsed`` is the epoch wall time in seconds.
    """
    model.train()
    D_model.train()
    args = state.args
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    save_image = functools.partial(save_image_writer, writer, state.epoch)

    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()
    perf_avg = 0
    # D updates are paused once D_grad_acc_step reaches a multiple of this.
    n_critic = 5 * state.grad_acc
    # LSGAN: least-squares loss against constant real/fake labels.
    adversarial_loss = torch.nn.MSELoss()
    # NOTE(review): the label device follows torch.cuda.is_available(), not
    # args.device; these can disagree (e.g. a CPU run on a GPU machine).
    cuda = True if torch.cuda.is_available() else False
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    # 1x1 constant labels; assumes D_model's output broadcasts against (1, 1)
    # — TODO confirm.
    valid = Variable(Tensor(1, 1).fill_(1.0), requires_grad=False)
    fake = Variable(Tensor(1, 1).fill_(0.0), requires_grad=False)
    update_D = True
    D_loss_epoch = []
    G_loss_epoch = []
    for iter, sample in enumerate(data_loader):
        t0 = time.perf_counter()
        sample = data_mc._read_data(sample, keys, attr_keys, args.device)
        if update_D:
            output = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']
            output = postprocess_gan(output, (rec_x, rec_y))
            target = postprocess_gan(sample['target'], (rec_x, rec_y))
            sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))
            # pdb.set_trace()
            # Adversarial loss
            x_outputs = D_model(sample['fg_mask'] * target)
            # detach(): the discriminator loss must not reach the generator.
            z_outputs = D_model(sample['fg_mask'] * output.detach())
            # LSGAN
            real_loss = adversarial_loss(x_outputs, valid)
            fake_loss = adversarial_loss(z_outputs, fake)
            D_loss = 0.5 * (real_loss + fake_loss)
            # NOTE(review): retain_graph looks unnecessary since the generator
            # output is detached above and the G pass below rebuilds its graph.
            D_loss.backward(retain_graph=True)
            if state.D_grad_acc_step % state.grad_acc == 0:
                D_opt.step()
                D_opt.zero_grad()
            state.D_grad_acc_step += 1
            D_loss_epoch.append(D_loss.item())
            if state.D_grad_acc_step % n_critic == 0:
                update_D = False
        # Training Generator
        output = model(
            sample['input'],
            sample['kspace'],
            sample['smaps'],
            sample['mask'],
            sample['attrs'],
        )
        rec_x = sample['attrs']['metadata']['rec_x']
        rec_y = sample['attrs']['metadata']['rec_y']
        output = postprocess_gan(output, (rec_x, rec_y))
        # NOTE(review): if the D branch ran this iteration, sample['fg_mask']
        # was already postprocessed there and is postprocessed again here;
        # confirm postprocess is idempotent for the mask.
        sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))
        z_outputs = D_model(sample['fg_mask'] * output)
        # NOTE(review): target is only recomputed here for 5-D targets;
        # otherwise the value from the D branch is used, which can come from an
        # earlier iteration when update_D was False.
        if len(sample['target'].shape) == 5:
            target = postprocess_gan(sample['target'], (rec_x, rec_y))
        loss, loss_l1, loss_ssim = state.loss_fn(
            output=output[:, 0],
            target=target[:, 0],
            sample=sample,
            scale=1. / args.grad_acc,
        )
        # LSGAN: generator tries to make D label its output as real
        G_loss = adversarial_loss(z_outputs, valid) + 0.1 * loss
        # pdb.set_trace()
        t1 = time.perf_counter()
        # NOTE(review): z_outputs is not detached from D_model, so unless its
        # parameters have requires_grad=False this backward also accumulates
        # generator-objective gradients into D_model's .grad. They are only
        # cleared by D_opt.zero_grad(), which runs after the next D_opt.step().
        G_loss.backward()
        if state.global_step % state.grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad()
            # Re-enable discriminator updates after every generator step.
            update_D = True
        t2 = time.perf_counter()
        G_loss_epoch.append(G_loss.item())
        perf = t2 - start_iter
        perf_avg += perf
        state.global_step += 1

        # avg_loss = 0.99 * avg_loss + (0.01 if iter > 0 else 1) * loss.item()
        if (iter + 1) % args.report_interval == 0:
            avg_loss = np.mean(G_loss_epoch)
            D_avg_loss = np.mean(D_loss_epoch)
            # NOTE(review): logs the most recent D_loss tensor, not D_avg_loss.
            writer.add_scalar('DiscriminatorLoss', D_loss, state.global_step)
            writer.add_scalar('TrainGANLoss', G_loss.item(), state.global_step)
            writer.add_scalar('TrainLoss', loss.item(), state.global_step)
            writer.add_scalar('TrainL1Loss', loss_l1.item(), state.global_step)
            writer.add_scalar(
                'TrainSSIMLoss',
                loss_ssim.item(),
                state.global_step,
            )
            logging.info(
                f'Epoch = [{state.epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} D_Loss = {D_avg_loss:.4g} '
                f't = (tot:{perf_avg/(iter+1):.1g}s'
                f'/fwd:{t1-t0:.1g}/bwd:{t2-t1:.1g}s)')

        # Every 1000 generator iterations: example images + temp checkpoint.
        if state.global_step % 1000 == 0:
            input_abs = postprocess_gan(sample['input'], (rec_x, rec_y))
            base_err = torch.abs(target - input_abs)
            pred_err = torch.abs(target - output)
            residual = torch.abs(input_abs - output)
            save_image(
                torch.cat([input_abs, output, target], -1).unsqueeze(0),
                'Train_und_pred_gt',
            )
            save_image(
                torch.cat([base_err, pred_err, residual], -1).unsqueeze(0),
                'Train_Err_base_pred',
                base_err.max(),
            )
            save_model(args, args.exp_dir, state.epoch, model, optimizer,
                       D_model, D_opt, avg_loss, is_new_best=False,
                       modelname='model_tmp.pt')
        start_iter = time.perf_counter()
        # if iter == 1000:
        #     break
    return avg_loss, time.perf_counter() - start_epoch