Example #1
def run_sn(args, data_loader, model):
    """ Run Sigmanet """
    model.eval()
    logging.info('Run Sigmanet reconstruction')
    logging.info(f'Arguments: {args}')
    reconstructions = defaultdict(list)
    # keys = ['input', 'kspace', 'smaps', 'mask', 'fg_mask']
    # if args.mask_bg:
    #     keys.append('input_rss_mean')
    # attr_keys = ['mean', 'cov', 'norm']

    with torch.no_grad():
        for ii, sample in enumerate(tqdm(data_loader)):
            sample = data_batch._read_data(sample, device=args.device)

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = model(sample['input'], sample['kspace'], sample['smaps'],
                      sample['mask'], sample['attrs'])

            recons = postprocess(x, (rec_x, rec_y))

            # mask background using background mean value
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'],
                    (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # renormalize
            norm = sample['attrs']['norm'].reshape(-1, 1, 1)
            recons = recons * norm

            recons = recons.to('cpu').numpy()

            if args.debug and ii % 10 == 0:
                plt.imsave(
                    'run_sn_progress.png',
                    np.hstack(recons),
                    cmap='gray',
                )

            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }

    save_reconstructions(reconstructions, args.out_dir)
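
The reconstruction drivers above lean on `postprocess` and `center_crop` helpers that are not shown. A minimal sketch of what they plausibly do, assuming complex-valued network outputs and fastMRI-style center cropping; the repository's actual helpers may differ (e.g. complex values stored as a trailing dimension of size 2):

import torch

def center_crop(data, shape):
    # crop the last two spatial dimensions of `data` to `shape`, centered
    h_from = (data.shape[-2] - shape[0]) // 2
    w_from = (data.shape[-1] - shape[1]) // 2
    return data[..., h_from:h_from + shape[0], w_from:w_from + shape[1]]

def postprocess(x, shape):
    # magnitude image, center-cropped to the target reconstruction size
    if torch.is_complex(x):
        x = torch.abs(x)
    return center_crop(x, shape)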
Example #2
def run_zero_filled_sense(args, data_loader):
    """ Run Adjoint (zero-filled SENSE) reconstruction """
    logging.info('Run zero-filled SENSE reconstruction')
    logging.info(f'Arguments: {args}')
    reconstructions = defaultdict(list)

    with torch.no_grad():
        for sample in tqdm(data_loader):
            sample = data_batch._read_data(sample)

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            x = sample['input']

            recons = postprocess(x, (rec_x, rec_y))

            # mask background using background mean value
            if args.mask_bg:
                fg_mask = center_crop(
                    sample['fg_mask'],
                    (rec_x, rec_y),
                ).squeeze(1)
                if args.use_bg_noise_mean:
                    bg_mean = sample['input_rss_mean'].reshape(-1, 1, 1)
                    recons = recons * fg_mask + (1 - fg_mask) * bg_mean
                else:
                    recons = recons * fg_mask

            # renormalize
            norm = sample['attrs']['norm'].numpy()[:, np.newaxis, np.newaxis]
            recons = recons.numpy() * norm

            for bidx in range(recons.shape[0]):
                reconstructions[sample['fname']].append(
                    (sample['slidx'][bidx], recons[bidx]))

    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }

    save_reconstructions(reconstructions, args.out_dir)
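
`save_reconstructions` is called in both examples but never defined. A sketch of the fastMRI-style convention it appears to follow (one HDF5 file per volume, with the stacked slices stored under a 'reconstruction' key); treat the key name and layout as assumptions:

import pathlib
import h5py

def save_reconstructions(reconstructions, out_dir):
    # write each volume's stacked slices to <out_dir>/<fname> as HDF5
    out_dir = pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for fname, recons in reconstructions.items():
        with h5py.File(out_dir / fname, 'w') as f:
            f.create_dataset('reconstruction', data=recons)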
Example #3
def evaluate(state, model, data_loader, metrics, writer):
    model.eval()
    args = state.args  # needed below for args.device; was never bound here
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    losses = defaultdict(list)

    start = time.perf_counter()
    with torch.no_grad():
        for iter, sample in enumerate(data_loader):
            sample = data_mc._read_data(sample, keys, attr_keys, args.device)
            output = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            output = postprocess(output, (rec_x, rec_y))
            target = postprocess(sample['target'], (rec_x, rec_y))
            sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))
            loss = state.loss_fn(
                output=output,
                target=target,
                sample=sample,
                scale=1. / state.grad_acc,
            )[0]
            losses['dev_loss'].append(loss.item())

            # evaluate in the foreground
            target = target.unsqueeze(1) * sample['fg_mask']
            output = output.unsqueeze(1) * sample['fg_mask']
            for k in metrics:
                losses[k].append(metrics[k](target, output, sample).item())

        for k in losses:
            writer.add_scalar(f'Dev_{k}', np.mean(losses[k]), state.epoch)

    return losses, time.perf_counter() - start
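
`evaluate` expects `metrics` to be a dict of callables with the signature `metric(target, output, sample) -> scalar tensor`. Two illustrative entries; this is a sketch only, and the names and formulas are assumptions rather than the repository's definitions:

import torch

def nmse(target, output, sample):
    # normalized mean squared error over the foreground-masked images
    return torch.norm(target - output) ** 2 / torch.norm(target) ** 2

def psnr(target, output, sample):
    # peak signal-to-noise ratio, using the target maximum as the peak
    mse = torch.mean((target - output) ** 2)
    return 10 * torch.log10(target.max() ** 2 / mse)

metrics = {'NMSE': nmse, 'PSNR': psnr}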
Example #4
def visualize(state, model, data_loader, writer):
    save_image = functools.partial(save_image_writer, writer, state.epoch)
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    model.eval()
    args = state.args  # needed below for args.device; was never bound here
    with torch.no_grad():
        for iter, sample in enumerate(data_loader):
            sample = data_mc._read_data(sample, keys, attr_keys, args.device)
            output = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            output = postprocess(output, (rec_x, rec_y))

            input = postprocess(sample['input'], (rec_x, rec_y))
            target = postprocess(sample['target'], (rec_x, rec_y))
            fg_mask = postprocess(sample['fg_mask'], (rec_x, rec_y))

            base_err = torch.abs(target - input)
            pred_err = torch.abs(target - output)
            residual = torch.abs(input - output)
            save_image(
                torch.cat([input, output, target], -1).unsqueeze(0),
                'und_pred_gt',
            )
            save_image(
                torch.cat([base_err, pred_err, residual], -1).unsqueeze(0),
                'Err_base_pred',
                base_err.max(),
            )
            save_image(fg_mask, 'Mask', 1.)

            break
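
`save_image_writer` only appears through `functools.partial(save_image_writer, writer, state.epoch)`, so its signature is presumably `(writer, epoch, image, tag, max_val=None)`. A hypothetical minimal version that normalizes the tensor and logs one grayscale slice to TensorBoard:

def save_image_writer(writer, epoch, image, tag, max_val=None):
    # normalize to [0, 1] and log a single 2D slice as a grayscale image
    if max_val is None:
        max_val = image.max()
    img = (image / max_val).clamp(0, 1)
    while img.dim() > 2:  # reduce (1, N, H, W)-style batches to 2D
        img = img[0]
    writer.add_image(tag, img.unsqueeze(0), epoch)  # CHW with C=1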
Example #5
def train_epoch(state, model, data_loader, optimizer, writer):
    model.train()
    args = state.args
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    save_image = functools.partial(save_image_writer, writer, state.epoch)
    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()

    perf_avg = 0
    for iter, sample in enumerate(data_loader):
        t0 = time.perf_counter()
        sample = data_mc._read_data(sample, keys, attr_keys, args.device)
        output = model(
            sample['input'],
            sample['kspace'],
            sample['smaps'],
            sample['mask'],
            sample['attrs'],
        )

        rec_x = sample['attrs']['metadata']['rec_x']
        rec_y = sample['attrs']['metadata']['rec_y']

        output = postprocess(output, (rec_x, rec_y))

        target = postprocess(sample['target'], (rec_x, rec_y))
        sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))
        loss, loss_l1, loss_ssim = state.loss_fn(
            output=output,
            target=target,
            sample=sample,
            scale=1. / args.grad_acc,
        )
        t1 = time.perf_counter()
        loss.backward()
        if state.global_step % state.grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad()
        state.global_step += 1
        t2 = time.perf_counter()
        perf = t2 - start_iter
        perf_avg += perf

        avg_loss = 0.99 * avg_loss + (0.01 if iter > 0 else 1) * loss.item()
        if iter % args.report_interval == 0:
            writer.add_scalar('TrainLoss', loss.item(), state.global_step)
            writer.add_scalar('TrainL1Loss', loss_l1.item(), state.global_step)
            writer.add_scalar(
                'TrainSSIMLoss',
                loss_ssim.item(),
                state.global_step,
            )
            logging.info(f'Epoch = [{state.epoch:3d}/{args.num_epochs:3d}] '
                         f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                         f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                         f't = (tot:{perf_avg/(iter+1):.1g}s'
                         f'/fwd:{t1-t0:.1g}/bwd:{t2-t1:.1g}s)')

        if state.global_step % 1000 == 0:
            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            input_abs = postprocess(sample['input'], (rec_x, rec_y))
            base_err = torch.abs(target - input_abs)
            pred_err = torch.abs(target - output)
            residual = torch.abs(input_abs - output)
            save_image(
                torch.cat([input_abs, output, target], -1).unsqueeze(0),
                'Train_und_pred_gt',
            )
            save_image(
                torch.cat([base_err, pred_err, residual], -1).unsqueeze(0),
                'Train_Err_base_pred',
                base_err.max(),
            )
            save_model(args,
                       args.exp_dir,
                       state.epoch,
                       model,
                       optimizer,
                       avg_loss,
                       is_new_best=False,
                       modelname='model_tmp.pt')

        start_iter = time.perf_counter()

    return avg_loss, time.perf_counter() - start_epoch
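
`state.loss_fn` is unpacked into `(loss, loss_l1, loss_ssim)`, i.e. a weighted L1 + SSIM objective. A sketch of such a loss, using a crude global-statistics SSIM in place of the windowed SSIM the repository likely uses; the signature and weighting are assumptions:

import torch
import torch.nn.functional as F

def make_loss_fn(alpha=0.5):
    def loss_fn(output, target, sample, scale=1.0):
        loss_l1 = F.l1_loss(output, target)
        # global-statistics SSIM (no sliding window), for illustration only
        mu_x, mu_y = output.mean(), target.mean()
        cov = ((output - mu_x) * (target - mu_y)).mean()
        c1, c2 = 0.01 ** 2, 0.03 ** 2
        ssim = ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
            (mu_x ** 2 + mu_y ** 2 + c1) * (output.var() + target.var() + c2))
        loss_ssim = 1. - ssim
        loss = scale * (alpha * loss_l1 + (1. - alpha) * loss_ssim)
        return loss, loss_l1, loss_ssim
    return loss_fn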
Example #6
def train_epoch(state, model, D_model, data_loader, optimizer, D_opt, writer):
    model.train()
    D_model.train()
    args = state.args
    keys = ['input', 'target', 'kspace', 'smaps', 'mask', 'fg_mask']
    attr_keys = ['mean', 'cov', 'ref_max']
    save_image = functools.partial(save_image_writer, writer, state.epoch)
    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()

    perf_avg = 0

    n_critic = 5 * state.grad_acc
    adversarial_loss = torch.nn.MSELoss()
    # plain tensors replace the deprecated torch.autograd.Variable
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    valid = torch.ones(1, 1, device=device)   # LSGAN "real" label
    fake = torch.zeros(1, 1, device=device)   # LSGAN "fake" label
    update_D = True
    D_loss_epoch = []
    G_loss_epoch = []

    for iter, sample in enumerate(data_loader):
        t0 = time.perf_counter()
        sample = data_mc._read_data(sample, keys, attr_keys, args.device)
        if update_D:
            output = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            output = postprocess_gan(output, (rec_x, rec_y))

            target = postprocess_gan(sample['target'], (rec_x, rec_y))
            sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))

            # Adversarial loss
            x_outputs = D_model(sample['fg_mask'] * target)
            z_outputs = D_model(sample['fg_mask'] * output.detach())

            # LSGAN
            real_loss = adversarial_loss(x_outputs, valid)
            fake_loss = adversarial_loss(z_outputs, fake)
            D_loss = 0.5 * (real_loss + fake_loss)
            D_loss.backward(retain_graph=True)

            if state.D_grad_acc_step % state.grad_acc == 0:
                D_opt.step()
                D_opt.zero_grad()
            state.D_grad_acc_step += 1
            D_loss_epoch.append(D_loss.item())

        if state.D_grad_acc_step % n_critic == 0:
            update_D = False
            # Training Generator
            output = model(
                sample['input'],
                sample['kspace'],
                sample['smaps'],
                sample['mask'],
                sample['attrs'],
            )

            rec_x = sample['attrs']['metadata']['rec_x']
            rec_y = sample['attrs']['metadata']['rec_y']

            output = postprocess_gan(output, (rec_x, rec_y))
            sample['fg_mask'] = postprocess(sample['fg_mask'], (rec_x, rec_y))
            z_outputs = D_model(sample['fg_mask'] * output)

            if len(sample['target'].shape) == 5:
                target = postprocess_gan(sample['target'], (rec_x, rec_y))

            loss, loss_l1, loss_ssim = state.loss_fn(
                output=output[:, 0],
                target=target[:, 0],
                sample=sample,
                scale=1. / args.grad_acc,
            )
            # LSGAN
            G_loss = adversarial_loss(z_outputs, valid) + 0.1 * loss
            t1 = time.perf_counter()
            G_loss.backward()
            if state.global_step % state.grad_acc == 0:
                optimizer.step()
                optimizer.zero_grad()
                update_D = True

            t2 = time.perf_counter()
            G_loss_epoch.append(G_loss.item())
            perf = t2 - start_iter
            perf_avg += perf
        state.global_step += 1
        # avg_loss = 0.99 * avg_loss + (0.01 if iter > 0 else 1) * loss.item()
        if (iter + 1) % args.report_interval == 0:
            avg_loss = np.mean(G_loss_epoch)
            D_avg_loss = np.mean(D_loss_epoch)
            writer.add_scalar('DiscriminatorLoss', D_loss.item(),
                              state.global_step)
            writer.add_scalar('TrainGANLoss', G_loss.item(), state.global_step)
            writer.add_scalar('TrainLoss', loss.item(), state.global_step)
            writer.add_scalar('TrainL1Loss', loss_l1.item(), state.global_step)
            writer.add_scalar(
                'TrainSSIMLoss',
                loss_ssim.item(),
                state.global_step,
            )
            logging.info(
                f'Epoch = [{state.epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'D_Loss = {D_avg_loss:.4g} '
                f't = (tot:{perf_avg/(iter+1):.1g}s'
                f'/fwd:{t1-t0:.1g}/bwd:{t2-t1:.1g}s)')

        if state.global_step % 1000 == 0:
            input_abs = postprocess_gan(sample['input'], (rec_x, rec_y))
            base_err = torch.abs(target - input_abs)
            pred_err = torch.abs(target - output)
            residual = torch.abs(input_abs - output)
            save_image(
                torch.cat([input_abs, output, target], -1).unsqueeze(0),
                'Train_und_pred_gt',
            )
            save_image(
                torch.cat([base_err, pred_err, residual], -1).unsqueeze(0),
                'Train_Err_base_pred',
                base_err.max(),
            )
            save_model(args,
                       args.exp_dir,
                       state.epoch,
                       model,
                       optimizer,
                       D_model,
                       D_opt,
                       avg_loss,
                       is_new_best=False,
                       modelname='model_tmp.pt')

        start_iter = time.perf_counter()

    return avg_loss, time.perf_counter() - start_epoch
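
The LSGAN loop compares `D_model`'s output against `valid`/`fake` labels of shape (1, 1), so the discriminator must reduce any input batch to a single score. A minimal sketch of such a network; the architecture is an assumption, not the repository's `D_model`:

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, in_ch=1, width=32):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_ch, width, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(width, 2 * width, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),          # (N, 2*width, 1, 1)
        )
        self.head = nn.Linear(2 * width, 1)

    def forward(self, x):
        score = self.head(self.features(x).flatten(1))  # (N, 1)
        return score.mean(0, keepdim=True)              # (1, 1)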