コード例 #1
0
ファイル: test_vsnet.py プロジェクト: wdika/mridc
def test_vsnet(shape, cfg, center_fractions, accelerations):
    """
    Run one gradient-free forward pass of VSNet on masked input and check the output shape.

    Args:
        shape: forwarded to ``create_input`` to build the input tensor.
        cfg: model configuration; wrapped with OmegaConf and round-tripped through
            ``OmegaConf.to_container(..., resolve=True)`` before being handed to ``VSNet``.
        center_fractions: first argument given to ``RandomMaskFunc``.
        accelerations: second argument given to ``RandomMaskFunc``.

    Returns:
        None

    Raises:
        AssertionError: if the prediction's shape, ignoring its first dimension,
            differs from ``x.shape[2:4]``.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # Mask each entry along the first axis on its own, always with the same seed.
    per_sample = (transforms.apply_mask(x[idx : idx + 1], mask_func, seed=123) for idx in range(x.shape[0]))
    masked_parts, mask_parts = [], []
    for masked, sample_mask, _ in per_sample:
        masked_parts.append(masked)
        mask_parts.append(sample_mask)

    masked_batch = torch.cat(masked_parts)
    mask_batch = torch.cat(mask_parts)

    # Round-trip through a plain container with resolve=True, then wrap again.
    resolved = OmegaConf.to_container(OmegaConf.create(cfg), resolve=True)
    model = VSNet(OmegaConf.create(resolved))

    with torch.no_grad():
        # The masked batch is passed for three of the four positional arguments;
        # the target is the magnitude of its complex view.
        target = torch.abs(torch.view_as_complex(masked_batch))
        y = model.forward(masked_batch, masked_batch, mask_batch, masked_batch, target=target)

    produced, expected = y.shape[1:], x.shape[2:4]
    if produced != expected:
        raise AssertionError
コード例 #2
0
def test_xpdnet(shape, cfg, center_fractions, accelerations):
    """
    Run one gradient-free forward pass of XPDNet on masked input and check the output shape.

    Args:
        shape: forwarded to ``create_input`` to build the input tensor.
        cfg: model configuration; wrapped with OmegaConf and round-tripped through
            ``OmegaConf.to_container(..., resolve=True)`` before being handed to ``XPDNet``.
        center_fractions: first argument given to ``RandomMaskFunc``.
        accelerations: second argument given to ``RandomMaskFunc``.

    Returns:
        None.

    Raises:
        AssertionError: if the prediction's shape, ignoring its first dimension,
            differs from ``x.shape[2:4]``.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # Each length-1 slice along the first axis is masked separately with a fixed seed.
    n_samples = x.shape[0]
    masked_parts = []
    mask_parts = []
    for start in range(n_samples):
        sample = x[start : start + 1]
        masked, sample_mask, _ = transforms.apply_mask(sample, mask_func, seed=123)
        masked_parts.append(masked)
        mask_parts.append(sample_mask)

    masked_batch = torch.cat(masked_parts)
    mask_batch = torch.cat(mask_parts)

    # Round-trip through a plain container with resolve=True, then wrap again.
    plain_cfg = OmegaConf.to_container(OmegaConf.create(cfg), resolve=True)
    model = XPDNet(OmegaConf.create(plain_cfg))

    with torch.no_grad():
        # The masked batch is passed for three of the four positional arguments;
        # the target is the magnitude of its complex view.
        magnitude = torch.abs(torch.view_as_complex(masked_batch))
        y = model.forward(masked_batch, masked_batch, mask_batch, masked_batch, target=magnitude)

    out_dims = y.shape[1:]
    in_dims = x.shape[2:4]
    if out_dims != in_dims:
        raise AssertionError