Example #1
def train_step(model, data, device):
    input, target, mean, std, norm, _, mean_abs, std_abs = data
    input = input.to(device)
    target = target.to(device)
    output = model(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    mean = mean.to(device)
    std = std.to(device)

    if TRAIN_COMPLEX and not RENORM:
        output = transforms.unnormalize(output, mean, std)

    elif not TRAIN_COMPLEX:
        output = transforms.unnormalize(output, mean, std)
        output = transforms.complex_abs(output)
        if RENORM:
            mean_abs = mean_abs.unsqueeze(1).unsqueeze(2).to(device)
            std_abs = std_abs.unsqueeze(1).unsqueeze(2).to(device)
            output = transforms.normalize(output, mean_abs, std_abs)

    loss_f = F.smooth_l1_loss if SMOOTH else F.l1_loss
    loss = loss_f(output, target)
    if RENORM:
        return loss
    # Heuristic rescaling: without renormalization the values (and hence
    # the loss) are tiny, so scale the loss up.
    return 1e9 * loss
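
These examples lean on a small set of transforms helpers. Below is a minimal sketch of plausible definitions, modeled on the fastMRI data transforms; unnormalize is not part of stock fastMRI and is inferred here as the exact inverse of normalize.

import torch

def normalize(data, mean, stddev, eps=0.0):
    # Standardize with the given statistics.
    return (data - mean) / (stddev + eps)

def unnormalize(data, mean, stddev):
    # Inverse of normalize: map standardized values back to raw units.
    return data * stddev + mean

def normalize_instance(data, eps=0.0):
    # Standardize with the tensor's own statistics and return them so the
    # transform can be inverted later.
    mean = data.mean()
    std = data.std()
    return normalize(data, mean, std, eps), mean, std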
Example #2
def train_step(model, data, device):
    input, target, mean, std, mean_image, std_image, mask = data
    input = input.to(device)
    mask = mask.to(device)
    target = target.to(device)
    output = model(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    # Projection to consistent K-space
    output = input * mask + (1-mask) * output
    
    # Consistent K-space loss (with the normalized output and target)
    loss_k_consistent = F.l1_loss(output, target) 

    mean = mean.to(device)
    std = std.to(device)

    target = transforms.unnormalize(target, mean, std)
    output = transforms.unnormalize(output, mean, std)

    output_image = transforms.ifft2(output)
    target_image = transforms.ifft2(target)

    output_image = transforms.complex_center_crop(output_image, (320, 320))
    output_image = transforms.complex_abs(output_image)
    target_image = transforms.complex_center_crop(target_image, (320, 320))
    target_image = transforms.complex_abs(target_image)
    mean_image = mean_image.unsqueeze(1).unsqueeze(2).to(device)
    std_image = std_image.unsqueeze(1).unsqueeze(2).to(device)
    output_image = transforms.normalize(output_image, mean_image, std_image)
    target_image = transforms.normalize(target_image, mean_image, std_image)
    target_image = target_image.clamp(-6, 6)
    # Image-domain loss (on the cropped magnitude images, re-normalized with image statistics)
    loss_image = F.l1_loss(output_image, target_image)
    loss = loss_k_consistent + loss_image
    return loss
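
Example #2 also relies on the complex-tensor helpers ifft2, complex_abs, and complex_center_crop. A hedged sketch follows, assuming the fastMRI convention of a trailing dimension of size 2 for the (real, imaginary) parts; it is written against the modern torch.fft module, whereas the originals may use the legacy torch.ifft API.

import torch

def ifft2(data):
    # Centered 2D inverse FFT over the spatial dims of [..., H, W, 2].
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.ifft2(x, dim=(-2, -1), norm="ortho")
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)

def complex_abs(data):
    # Magnitude of [..., 2] complex data -> [...].
    return (data ** 2).sum(dim=-1).sqrt()

def complex_center_crop(data, shape):
    # Crop the central shape[0] x shape[1] patch of [..., H, W, 2].
    h_from = (data.shape[-3] - shape[0]) // 2
    w_from = (data.shape[-2] - shape[1]) // 2
    return data[..., h_from:h_from + shape[0], w_from:w_from + shape[1], :]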
Example #3
def generate(generator, data, device):

    input, _, mean, std, mask, _, _, _ = data
    input = input.to(device)
    mask = mask.to(device)

    output_network = generator(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    # Projection to consistent K-space
    output_consistent, target_kspace, output_kspace = project_to_consistent_subspace(
        output_network, input, mask)

    # The downstream loss is taken on the cropped, real-valued (magnitude) image
    mean = mean.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device)
    std = std.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(device)
    output_consistent = transforms.unnormalize(output_consistent, mean, std)
    output_consistent = transforms.complex_center_crop(output_consistent,
                                                       (320, 320))
    output_consistent = transforms.complex_abs(output_consistent)

    output_network = transforms.unnormalize(output_network, mean, std)
    output_network = transforms.complex_center_crop(output_network, (320, 320))
    output_network = transforms.complex_abs(output_network)

    return output_consistent, output_network, target_kspace, output_kspace
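
project_to_consistent_subspace is not defined in these examples. One plausible definition, inferred from the inline data-consistency step in Examples #2 and #8, is sketched below; the extra k-space tensors returned here match the 3-tuple unpacked above, and the single-return variant in Example #10 presumably drops them.

def project_to_consistent_subspace(output, input, mask):
    # Keep measured k-space samples from the input; fill unmeasured
    # locations from the network output.
    output_consistent = input * mask + (1 - mask) * output
    # Also expose the raw k-space tensors, e.g. for a k-space loss or
    # discriminator.
    return output_consistent, input, output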
Example #4
def visualize(args, epoch, model, inference, data_loader, writer):
    
    def save_image(image, tag):
        # Normalize to [0, 1] without mutating the caller's tensor
        # (the original in-place ops corrupted `target` before the
        # error image was computed).
        image = image - image.min()
        image = image / image.max()
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            output, target = inference(model, data, device=args.device)

            # HACK to make images look good in tensorboard
            output, mean_o, std_o = transforms.normalize_instance(output)
            output = output.clamp(-6, 6)
            output = transforms.unnormalize(output, mean_o, std_o)
            target, mean_t, std_t = transforms.normalize_instance(target)
            target = target.clamp(-6, 6)
            target = transforms.unnormalize(target, mean_t, std_t)

            output = output.unsqueeze(1) # [batch_sz, h, w] --> [batch_sz, 1, h, w]
            target = target.unsqueeze(1) # [batch_sz, h, w] --> [batch_sz, 1, h, w]
            if isinstance(output, dict):
                for k, output_val in output.items():
                    # save_image(input, 'Input_{}'.format(k))
                    save_image(target, 'Target_{}'.format(k))
                    save_image(output_val, 'Reconstruction_{}'.format(k))
                    save_image(torch.abs(target - output_val), 'Error_{}'.format(k))
            else:
                # save_image(input, 'Input')
                save_image(target, 'Target')
                save_image(output, 'Reconstruction')
                save_image(torch.abs(target - output), 'Error')
            break
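
For context, a hypothetical call site for visualize; the names display_loader, args.exp_dir, and args.num_epochs are assumptions, not taken from the examples.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=str(args.exp_dir / 'summary'))
for epoch in range(args.num_epochs):
    # ... one training epoch ...
    visualize(args, epoch, model, inference, display_loader, writer)
writer.close()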
Example #5
def inference(model, data, device):
    input, _, mean, std, _, target = data
    input = input.unsqueeze(1).to(device)
    target = target.to(device)
    output = model(input).squeeze(1)

    mean = mean.unsqueeze(1).unsqueeze(2).to(device)
    std = std.unsqueeze(1).unsqueeze(2).to(device)
    
    target = transforms.unnormalize(target, mean, std)
    output = transforms.unnormalize(output, mean, std)
    return output, target
Example #6
    def inference(self, model, data, device):
        input, _, mean, std, _, target = data
        input = input.unsqueeze(1).to(device)
        target = target.to(device)
        output = model(input).squeeze(1)
        # Channel 0 is the predicted image; channel 1 parameterizes a
        # per-pixel std-dev via exp (+ eps for numerical stability).
        sigmas = torch.exp(output[:, 1, :, :]) + 1e-4
        output = output[:, 0, :, :]

        mean = mean.unsqueeze(1).unsqueeze(2).to(device)
        std = std.unsqueeze(1).unsqueeze(2).to(device)

        target = transforms.unnormalize(target, mean, std)
        output = transforms.unnormalize(output, mean, std)
        # A std-dev transforms under x -> x * std + mean by scaling alone;
        # adding the mean (as unnormalize would) would bias the estimate.
        sigmas = sigmas * std
        confidence = -(sigmas**2).sum(dim=2).sum(dim=1)
        return output, target, confidence, sigmas
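
The two-channel head in Example #6 (a predicted image plus a per-pixel std-dev via exp) is the standard heteroscedastic-regression setup. These examples only show the inference side; a matching training loss would be a Gaussian negative log-likelihood, sketched here as an assumption rather than the authors' actual objective.

import torch

def gaussian_nll(mu, sigma, target):
    # -log N(target | mu, sigma^2), dropping the additive constant.
    # The log(sigma) term penalizes overconfident (too small) variances.
    return (torch.log(sigma) + 0.5 * ((target - mu) / sigma) ** 2).mean()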
Example #7
def inference(model, data, device):
    input, _, mean, std, _, target, mean_abs, std_abs = data
    input = input.to(device)
    target = target.to(device)
    output = model(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    mean = mean.to(device)
    std = std.to(device)

    output = transforms.unnormalize(output, mean, std)
    output = transforms.complex_abs(output)
    return output, target
Example #8
def inference(model, data, device):
    with torch.no_grad():
        input, target, mean, std, _, _, mask = data
        input = input.to(device)
        mask = mask.to(device)
        output = model(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # Projection to consistent k-space: keep the measured samples
        output = input * mask + (1-mask) * output
        target = target.to(device)

        mean = mean.to(device)
        std = std.to(device)

        output = transforms.unnormalize(output, mean, std)
        target = transforms.unnormalize(target, mean, std)

        output = transforms.ifft2(output)
        target = transforms.ifft2(target)

        output = transforms.complex_center_crop(output, (320, 320))
        output = transforms.complex_abs(output)
        target = transforms.complex_center_crop(target, (320, 320))
        target = transforms.complex_abs(target)

        return output, target
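
A hypothetical evaluation loop around this inference function; val_loader is assumed, and F is torch.nn.functional as in the examples above.

total_l1, n_pixels = 0.0, 0
for data in val_loader:
    output, target = inference(model, data, device)
    total_l1 += F.l1_loss(output, target, reduction='sum').item()
    n_pixels += target.numel()
print('mean image-domain L1:', total_l1 / n_pixels)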
Example #9
def generate(generator, data, device):

    input, _, mean, std, mask, _, _, _ = data
    input = input.to(device)

    output_network = generator(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    # The downstream loss is taken on the cropped, real-valued (magnitude) image
    mean = mean.to(device)
    std = std.to(device)
    output_network = transforms.unnormalize(output_network, mean, std)
    output_network = transforms.complex_center_crop(output_network, (320, 320))
    output_network = transforms.complex_abs(output_network)

    return output_network
Example #10
def generate(generator, data, device):

    input, _, mean, std, mask, _, _, _ = data
    input = input.to(device)
    mask = mask.to(device)

    # Use network to predict residual
    residual = generator(input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    # Projection to consistent K-space
    if PROJECT:
        output = project_to_consistent_subspace(residual, input, mask)
    else:
        # Without projection, use the network output directly so that
        # `output` is always defined below.
        output = residual

    # The downstream loss is taken on the cropped, real-valued (magnitude) image
    mean = mean.to(device)
    std = std.to(device)
    output = transforms.unnormalize(output, mean, std)
    output = transforms.complex_center_crop(output, (320, 320))
    output = transforms.complex_abs(output)

    return output
Example #11
def inference(model, data, device):
    input, target, mean, std, norm, unnormalized_target, image_updated = data
    if len(target) != 0:
        target, _, _ = transforms.normalize_instance(target, eps=1e-11)
    if len(image_updated) != 0:
        input = image_updated
    input, mean, std = transforms.normalize_instance(input, eps=1e-11)
    if CLAMP:
        input = input.clamp(-6, 6)

    input = input.unsqueeze(0).unsqueeze(1).to(device)
    if len(unnormalized_target) != 0:
        unnormalized_target = unnormalized_target.to(device)
    output = model(input).squeeze(1).squeeze(0)

    mean = mean.unsqueeze(0).unsqueeze(1).unsqueeze(2).to(device)
    std = std.unsqueeze(0).unsqueeze(1).unsqueeze(2).to(device)
    output = transforms.unnormalize(output, mean, std)
    return output, unnormalized_target
Example #12
def visualize(args, epoch, model, inference, data_loader, writer):
    def save_image(image, tag):
        # Normalize to [0, 1] without mutating the caller's tensor.
        image = image - image.min()
        image = image / image.max()
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    def overlay_uncertainty(image, uncertainty, tag):
        # Normalize both tensors to [0, 1], again without in-place edits.
        image = image - image.min()
        image = image / image.max()
        uncertainty = uncertainty - uncertainty.min()
        uncertainty = uncertainty / uncertainty.max()
        # Convert to RGB
        image = image.expand(-1, 3, -1, -1)
        uncertainty = torch.cat(
            [torch.zeros_like(uncertainty), uncertainty, uncertainty], 1)
        # Overlay
        image = image * 0.7 + uncertainty * 0.3
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            output, target, _, sigmas = wrap_confidence(
                inference(model, data, device=args.device))
            # HACK to make images look good in tensorboard
            output, mean_o, std_o = transforms.normalize_instance(output)
            output = output.clamp(-6, 6)
            output = transforms.unnormalize(output, mean_o, std_o)
            target, mean_t, std_t = transforms.normalize_instance(target)
            target = target.clamp(-6, 6)
            target = transforms.unnormalize(target, mean_t, std_t)

            output = output.unsqueeze(
                1)  # [batch_sz, h, w] --> [batch_sz, 1, h, w]
            target = target.unsqueeze(
                1)  # [batch_sz, h, w] --> [batch_sz, 1, h, w]
            error = torch.abs(target - output)

            if sigmas is not None:
                sigmas, mean_s, std_s = transforms.normalize_instance(sigmas)
                sigmas = sigmas.clamp(-6, 6)
                sigmas = transforms.unnormalize(sigmas, mean_s, std_s)
                sigmas = sigmas.unsqueeze(
                    1)  # [batch_sz, h, w] --> [batch_sz, 1, h, w]

            if isinstance(output, dict):
                for k, output_val in output.items():
                    # save_image(input, 'Input_{}'.format(k))
                    save_image(target, 'Target_{}'.format(k))
                    save_image(output_val, 'Reconstruction_{}'.format(k))
                    save_image(error, 'Error_{}'.format(k))
                    if sigmas is not None:
                        save_image(sigmas, 'Std_{}'.format(k))
                        overlay_uncertainty(error, sigmas,
                                            'Overlay_Error_Std_{}'.format(k))
                        overlay_uncertainty(
                            output_val, sigmas,
                            'Overlay_Reconstruction_Std_{}'.format(k))
            else:
                # save_image(input, 'Input')
                save_image(target, 'Target')
                save_image(output, 'Reconstruction')
                save_image(error, 'Error')
                if sigmas is not None:
                    save_image(sigmas, 'Std')
                    overlay_uncertainty(error, sigmas, 'Overlay_Error_Std')
                    overlay_uncertainty(output, sigmas,
                                        'Overlay_Reconstruction_Std')
            break
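
wrap_confidence is not defined in these examples. One plausible definition, inferred from its call site: pad whatever tuple inference returns to the 4-tuple (output, target, confidence, sigmas), so that models without an uncertainty head share the same visualization path.

def wrap_confidence(result):
    # Pad an inference result to (output, target, confidence, sigmas),
    # filling missing trailing slots with None.
    result = tuple(result)
    return result + (None,) * (4 - len(result))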