def test_vsnet(shape, cfg, center_fractions, accelerations):
    """
    Test VSNet with different parameters.

    Builds an input with ``create_input``, masks every batch element with a
    ``RandomMaskFunc``, runs a no-grad forward pass of ``VSNet``, and checks the
    spatial shape of the prediction.

    Args:
        shape: shape of the input passed to ``create_input``.
        cfg: configuration of the model; converted to an OmegaConf object with
            all interpolations resolved before the model is built.
        center_fractions: center fractions passed to ``RandomMaskFunc``.
        accelerations: accelerations passed to ``RandomMaskFunc``.

    Returns:
        None

    Raises:
        AssertionError: if ``y.shape[1:]`` of the model output differs from
            ``x.shape[2:4]`` of the generated input.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # Mask each batch element on its own, always with the fixed seed 123,
    # then stack the per-element results back into a batch.
    outputs, masks = [], []
    for i in range(x.shape[0]):
        output, mask, _ = transforms.apply_mask(x[i : i + 1], mask_func, seed=123)
        outputs.append(output)
        masks.append(mask)
    output = torch.cat(outputs)
    mask = torch.cat(masks)

    # Round-trip through a plain container so every interpolation in the
    # config is resolved before VSNet reads it.
    cfg = OmegaConf.create(cfg)
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))

    vsnet = VSNet(cfg)
    with torch.no_grad():
        # NOTE(review): the masked data `output` is reused for several
        # positional inputs; confirm against VSNet.forward's signature.
        # The target is the magnitude of `output` viewed as complex (this
        # requires a trailing dimension of size 2).
        y = vsnet.forward(output, output, mask, output, target=torch.abs(torch.view_as_complex(output)))

    expected = x.shape[2:4]
    if y.shape[1:] != expected:
        raise AssertionError(
            f"VSNet output spatial shape {tuple(y.shape[1:])} does not match expected {tuple(expected)}"
        )
def test_xpdnet(shape, cfg, center_fractions, accelerations):
    """
    Test the XPDNet model.

    Builds an input with ``create_input``, masks every batch element with a
    ``RandomMaskFunc``, runs a no-grad forward pass of ``XPDNet``, and checks
    the spatial shape of the prediction.

    Args:
        shape: the shape of the input data passed to ``create_input``.
        cfg: the configuration of the model; converted to an OmegaConf object
            with all interpolations resolved before the model is built.
        center_fractions: the center fractions of the subsampling.
        accelerations: the accelerations of the subsampling.

    Returns:
        None.

    Raises:
        AssertionError: if ``y.shape[1:]`` of the model output differs from
            ``x.shape[2:4]`` of the generated input.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    # Mask each batch element on its own, always with the fixed seed 123,
    # then stack the per-element results back into a batch.
    outputs, masks = [], []
    for i in range(x.shape[0]):
        output, mask, _ = transforms.apply_mask(x[i : i + 1], mask_func, seed=123)
        outputs.append(output)
        masks.append(mask)
    output = torch.cat(outputs)
    mask = torch.cat(masks)

    # Round-trip through a plain container so every interpolation in the
    # config is resolved before XPDNet reads it.
    cfg = OmegaConf.create(cfg)
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))

    xpdnet = XPDNet(cfg)
    with torch.no_grad():
        # NOTE(review): the masked data `output` is reused for several
        # positional inputs; confirm against XPDNet.forward's signature.
        # The target is the magnitude of `output` viewed as complex (this
        # requires a trailing dimension of size 2).
        y = xpdnet.forward(output, output, mask, output, target=torch.abs(torch.view_as_complex(output)))

    expected = x.shape[2:4]
    if y.shape[1:] != expected:
        raise AssertionError(
            f"XPDNet output spatial shape {tuple(y.shape[1:])} does not match expected {tuple(expected)}"
        )