Example #1
0
def verify_net_in_out(
    net_id: Optional[str] = None,
    meta_file: Optional[Union[str, Sequence[str]]] = None,
    config_file: Optional[Union[str, Sequence[str]]] = None,
    device: Optional[str] = None,
    p: Optional[int] = None,
    n: Optional[int] = None,
    any: Optional[int] = None,
    args_file: Optional[str] = None,
    **override,
):
    """
    Check a network against the input/output description stored in the bundle metadata.

    The network is built from the config, moved to the target device and run once (in eval mode,
    without gradients) on a random tensor of batch size 1 whose channel count, spatial shape and
    dtype are read from the metadata. The channel count and dtype of the result are then compared
    with the values the metadata declares for the output.

    Typical usage examples:

    .. code-block:: bash

        python -m monai.bundle verify_net_in_out network --meta_file <meta path> --config_file <config path>

    Args:
        net_id: ID name of the network component to verify, it must be `torch.nn.Module`.
        meta_file: filepath of the metadata file that describes the network data format, if `None`,
            it must be provided in `args_file`. A list of file paths is merged into one content.
        config_file: filepath of the config file that defines the network, if `None`, it must be
            provided in `args_file`. A list of file paths is merged into one content.
        device: device used for the forward pass, if `None`, "cuda:0" is used when CUDA is available,
            otherwise "cpu".
        p: power factor used to build the fake data shape when a dim of the expected shape is "x**p",
            default to 1.
        n: multiply factor used to build the fake data shape when a dim of the expected shape is "x*n",
            default to 1.
        any: size used to build the fake data shape when a dim of the expected shape is "*",
            default to 1.
        args_file: a JSON or YAML file providing default values for `net_id`, `meta_file`, `config_file`,
            `device`, `p`, `n`, `any`, and override pairs, so that the command line can stay short.
        override: id-value pairs to override or add the corresponding config content.
            e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

    Raises:
        KeyError: when the network id or one of the expected metadata entries cannot be found.
        ValueError: when the output channel number or the output dtype differs from the metadata.

    """

    merged_args = _update_args(
        args=args_file,
        net_id=net_id,
        meta_file=meta_file,
        config_file=config_file,
        device=device,
        p=p,
        n=n,
        any=any,
        **override,
    )
    _log_input_summary(tag="verify_net_in_out", args=merged_args)

    fallback_device = "cuda:0" if is_available() else "cpu"
    config_path, meta_path, network_id, run_device, power, factor, wildcard = _pop_args(
        merged_args, "config_file", "meta_file", net_id="", device=fallback_device, p=1, n=1, any=1
    )

    parser = ConfigParser()
    parser.read_config(f=config_path)
    parser.read_meta(f=meta_path)

    # whatever is still left in the merged args is treated as a config override
    for override_id, override_value in merged_args.items():
        parser[override_id] = override_value

    # `key` always holds the id currently being looked up, so a KeyError can report it
    key: str = network_id
    try:
        net = parser.get_parsed_content(key).to(run_device)
        key = "_meta_#network_data_format#inputs#image#num_channels"
        in_channels = parser[key]
        key = "_meta_#network_data_format#inputs#image#spatial_shape"
        declared_spatial_shape = tuple(parser[key])
        key = "_meta_#network_data_format#inputs#image#dtype"
        in_dtype = get_equivalent_dtype(parser[key], torch.Tensor)
        key = "_meta_#network_data_format#outputs#pred#num_channels"
        output_channels = parser[key]
        key = "_meta_#network_data_format#outputs#pred#dtype"
        output_dtype = get_equivalent_dtype(parser[key], torch.Tensor)
    except KeyError as e:
        raise KeyError(f"Failed to verify due to missing expected key in the config: {key}.") from e

    net.eval()
    with torch.no_grad():
        fake_spatial_shape = _get_fake_spatial_shape(
            declared_spatial_shape, p=power, n=factor, any=wildcard
        )  # type: ignore
        # a single sample: (batch=1, channels, *spatial dims)
        fake_shape = (1, in_channels, *fake_spatial_shape)
        fake_input = torch.rand(*fake_shape, dtype=in_dtype, device=run_device)
        output = net(fake_input)

    # the comparisons only read tensor attributes, so they do not need the no-grad context
    if output.shape[1] != output_channels:
        raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.")
    if output.dtype != output_dtype:
        raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.")

    logger.info("data shape of network is verified with no error.")
 def test_value(self, argments, image, expected_data):
     """Run `HistogramNormalized` built from `argments` on `image` and check the "img" result.

     The values must be close to `expected_data`, and the result dtype must equal the
     "dtype" entry of `argments`, or `np.float32` when that entry is absent.
     """
     transform = HistogramNormalized(**argments)
     result = transform(image)["img"]
     assert_allclose(result, expected_data)

     # compare on the numpy side so the check does not depend on the container type of `result`
     actual_dtype = get_equivalent_dtype(result.dtype, data_type=np.ndarray)
     wanted_dtype = argments.get("dtype", np.float32)
     self.assertEqual(actual_dtype, wanted_dtype)