def run_inference(checkpoint, data_path, output_path):
    """Reconstruct every slice under ``data_path`` with a VarNet checkpoint.

    The ``"state_dict"`` entry of the checkpoint is filtered down to the keys
    that start with ``"varnet."``; the prefix is removed and the result is
    loaded into a freshly constructed ``VarNet``. Each batch is first run on
    CUDA and retried on the CPU if that raises ``RuntimeError``. The outputs
    are grouped per file name, sorted, stacked and handed to
    ``fastmri.save_reconstructions``. The elapsed wall-clock time is printed.

    Args:
        checkpoint: Checkpoint location passed to ``torch.load``. The loaded
            object must contain a ``"state_dict"`` mapping.
        data_path: Root passed to ``SliceDataset``.
        output_path: Object supporting the ``/`` operator; reconstructions
            are saved under ``output_path / "reconstructions"``.
    """
    prefix = "varnet."

    varnet = VarNet()
    # The model is built on the CPU, so load the weights there as well. This
    # also lets a checkpoint saved from a GPU load on a CPU-only machine.
    checkpoint_state = torch.load(checkpoint, map_location="cpu")["state_dict"]
    # Match the prefix itself: a plain substring test would also select keys
    # that merely contain "varnet" and then slice them at the wrong offset.
    state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint_state.items()
        if key.startswith(prefix)
    }

    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    data_transform = DataTransform()

    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge="multicoil",
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        crop_size = crop_size[0]  # always have a batch size of 1 for varnet
        fname = fname[0]  # always have batch size of 1 for varnet

        with torch.no_grad():
            try:
                device = torch.device("cuda")

                output = run_model(masked_kspace, mask, varnet, fname, device)
            except RuntimeError:
                # Best-effort fallback: retry the same batch on the CPU.
                print("running on cpu")
                device = torch.device("cpu")

                output = run_model(masked_kspace, mask, varnet, fname, device)

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    # Replace each per-file list of (slice_num, output) pairs by one stacked
    # array, ordered by sorting those pairs.
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
Ejemplo n.º 2
0
def test_varnet_num_sense_lines(shape, chans, center_fractions, accelerations,
                                mask_center):
    """Check VarNet's low-frequency bookkeeping and its output shape.

    A random mask is applied to a synthetic input. When ``mask_center`` is
    ``True`` the sensitivity network must report the same number of
    low-frequency lines as ``apply_mask`` did, and the mask must be all ones
    over the reported ``[pad, pad + num_low_freqs)`` range. In every case the
    model is then run with ``num_low_frequencies=4`` and the trailing output
    dimensions are compared with ``x.shape[2:4]``.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    masked_kspace, mask, num_low_freqs = transforms.apply_mask(
        x, mask_func, seed=123)

    model_kwargs = dict(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )
    varnet = VarNet(**model_kwargs)

    if mask_center is True:
        pad, net_low_freqs = varnet.sens_net.get_pad_and_num_low_freqs(
            mask, num_low_freqs)
        assert net_low_freqs == num_low_freqs

        start, stop = int(pad), int(pad + net_low_freqs)
        center_lines = mask.squeeze()[start:stop].to(torch.int8)
        all_sampled = torch.ones([int(net_low_freqs)], dtype=torch.int8)
        assert torch.allclose(center_lines, all_sampled)

    y = varnet(masked_kspace, mask.byte(), num_low_frequencies=4)

    assert y.shape[1:] == x.shape[2:4]
Ejemplo n.º 3
0
def test_varnet(shape, chans, center_fractions, accelerations, mask_center):
    """Run a small VarNet on a masked batch and check the output shape.

    Every sample of the synthetic input is masked on its own (always with
    ``seed=123``); the masked samples and their masks are concatenated back
    into batches before being fed to the model.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # One apply_mask call per sample, each on a length-1 slice of the batch.
    per_sample = [
        transforms.apply_mask(x[idx:idx + 1], mask_func, seed=123)
        for idx in range(x.shape[0])
    ]
    masked_kspace = torch.cat([masked for masked, _, _ in per_sample])
    mask = torch.cat([sample_mask for _, sample_mask, _ in per_sample])

    model_kwargs = dict(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )
    varnet = VarNet(**model_kwargs)

    y = varnet(masked_kspace, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
Ejemplo n.º 4
0
    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Args:
            num_cascades: Number of cascades (i.e., layers) for variational
                network.
            pools: Number of downsampling and upsampling layers for cascade
                U-Net.
            chans: Number of channels for cascade U-Net.
            sens_pools: Number of downsampling and upsampling layers for
                sensitivity map U-Net.
            sens_chans: Number of channels for sensitivity map U-Net.
            lr: Learning rate.
            lr_step_size: Learning rate step size.
            lr_gamma: Learning rate gamma decay.
            weight_decay: Parameter for penalizing weights norm.
            **kwargs: Any further keyword arguments; they are forwarded
                unchanged to the parent class initializer.

        Note:
            ``num_sense_lines`` is not a parameter of this constructor. If it
            is passed, it ends up in ``**kwargs`` and is forwarded to the
            parent initializer; it is not handed to ``VarNet``.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        # The network is built from the architecture settings stored above.
        self.varnet = VarNet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
        )

        self.loss = fastmri.SSIMLoss()
Ejemplo n.º 5
0
def test_varnet_scripting():
    """Check that a small VarNet can be compiled with ``torch.jit.script``."""
    model_kwargs = dict(
        num_cascades=4,
        pools=2,
        chans=8,
        sens_pools=2,
        sens_chans=4,
    )
    scripted = torch.jit.script(VarNet(**model_kwargs))
    assert scripted is not None
Ejemplo n.º 6
0
def test_varnet(shape, out_chans, chans, center_fractions, accelerations):
    """Run a small VarNet on a masked input and check the output shape.

    Args:
        shape: Shape passed to ``create_input``.
        out_chans: Accepted so the signature matches the test's parameter
            list; not used by the test body.
        chans: Channel count passed to ``VarNet`` as ``chans``.
        center_fractions: First argument of ``RandomMaskFunc``.
        accelerations: Second argument of ``RandomMaskFunc``.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    output, mask = transforms.apply_mask(x, mask_func, seed=123)

    # Build the model from the ``chans`` argument so that each test case
    # actually exercises the channel count it was given.
    varnet = VarNet(num_cascades=2,
                    sens_chans=4,
                    sens_pools=2,
                    chans=chans,
                    pools=2)

    y = varnet(output, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
Ejemplo n.º 7
0
    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Store the settings, then build the VarNet model and the SSIM loss.

        Args:
            num_cascades: Number of cascades (i.e., layers) of the
                variational network.
            pools: Number of down-/up-sampling layers in the cascade U-Net.
            chans: Number of channels in the cascade U-Net.
            sens_pools: Number of down-/up-sampling layers in the sensitivity
                map U-Net.
            sens_chans: Number of channels in the sensitivity map U-Net.
            lr: Learning rate.
            lr_step_size: Learning rate step size.
            lr_gamma: Learning rate gamma decay.
            weight_decay: Parameter for penalizing weights norm.
            **kwargs: Forwarded unchanged to the parent class initializer.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        # Learning-rate and weight-penalty settings.
        self.weight_decay = weight_decay
        self.lr_gamma = lr_gamma
        self.lr_step_size = lr_step_size
        self.lr = lr

        # Network size settings.
        self.sens_chans = sens_chans
        self.sens_pools = sens_pools
        self.chans = chans
        self.pools = pools
        self.num_cascades = num_cascades

        # The model is created before the loss, in that order.
        self.varnet = VarNet(
            num_cascades=num_cascades,
            sens_chans=sens_chans,
            sens_pools=sens_pools,
            chans=chans,
            pools=pools,
        )
        self.loss = fastmri.SSIMLoss()
def run_inference(challenge, state_dict_file, data_path, output_path, device):
    """Reconstruct every slice under ``data_path`` with a VarNet state dict.

    Args:
        challenge: Key into ``MODEL_FNAMES``; only consulted when
            ``state_dict_file`` is ``None``.
        state_dict_file: Location of the state dict passed to ``torch.load``.
            When ``None``, ``MODEL_FNAMES[challenge]`` is used and, if no
            file exists at that path, it is first fetched with
            ``download_model`` from ``VARNET_FOLDER``.
        data_path: Root passed to ``SliceDataset``.
        output_path: Object supporting the ``/`` operator; reconstructions
            are saved under ``output_path / "reconstructions"``.
        device: Device the model is moved to and that is passed to
            ``run_varnet_model`` for every batch.
    """
    model = VarNet(num_cascades=12,
                   pools=4,
                   chans=18,
                   sens_pools=4,
                   sens_chans=8)

    # download the state_dict if we don't have it
    if state_dict_file is None:
        state_dict_file = MODEL_FNAMES[challenge]
        if not Path(state_dict_file).exists():
            download_model(VARNET_FOLDER + state_dict_file, state_dict_file)

    # The model still lives on the CPU here, so load the weights there too;
    # this keeps GPU-saved weights loadable when ``device`` is the CPU. The
    # model is moved to ``device`` further down.
    model.load_state_dict(torch.load(state_dict_file, map_location="cpu"))
    model = model.eval()

    # data loader setup
    data_transform = T.VarNetDataTransform()
    dataset = SliceDataset(root=data_path,
                           transform=data_transform,
                           challenge="multicoil")
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_varnet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: one stacked array per file, ordered by sorting the
    # collected (slice_num, output) pairs
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
Ejemplo n.º 9
0
    def __init__(
        self,
        num_cascades=12,
        pools=4,
        chans=18,
        sens_pools=4,
        sens_chans=8,
        mask_type="equispaced",
        center_fractions=[0.08],
        accelerations=[4],
        lr=0.0003,
        lr_step_size=40,
        lr_gamma=0.1,
        weight_decay=0.0,
        **kwargs,
    ):
        """
        Args:
            num_cascades (int, default=12): Number of cascades (i.e., layers)
                for variational network.
            sens_chans (int, default=8): Number of channels for sensitivity map
                U-Net.
            sens_pools (int, default=4): Number of downsampling and upsampling
                layers for sensitivity map U-Net.
            chans (int, default=18): Number of channels for cascade U-Net.
            pools (int, default=4): Number of downsampling and upsampling
                layers for cascade U-Net.
            mask_type (str, default="equispaced"): Type of mask from ("random",
                "equispaced").
            center_fractions (list, default=[0.08]): Fraction of all samples to
                take from center (i.e., list of floats). Stored as a copy.
            accelerations (list, default=[4]): List of accelerations to apply
                (i.e., list of ints). Stored as a copy.
            lr (float, default=0.0003): Learning rate.
            lr_step_size (int, default=40): Learning rate step size.
            lr_gamma (float, default=0.1): Learning rate gamma decay.
            weight_decay (float, default=0): Parameter for penalizing weights
                norm.
            **kwargs: Forwarded unchanged to the parent class initializer.

        Raises:
            NotImplementedError: If ``self.batch_size`` is not 1 once the
                parent initializer has run.
        """
        super().__init__(**kwargs)

        # ``batch_size`` is read from the instance after the parent
        # initializer has run; it is not set in this method.
        if self.batch_size != 1:
            raise NotImplementedError(
                f"Only batch_size=1 allowed for {self.__class__.__name__}"
            )

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.mask_type = mask_type
        # Copy the lists: the signature defaults are shared list objects, so
        # storing them directly would let a mutation on one instance change
        # the defaults seen by every later instance.
        self.center_fractions = list(center_fractions)
        self.accelerations = list(accelerations)
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        self.varnet = VarNet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
        )

        self.loss = fastmri.SSIMLoss()