def verify_net_in_out(
    net_id: Optional[str] = None,
    meta_file: Optional[Union[str, Sequence[str]]] = None,
    config_file: Optional[Union[str, Sequence[str]]] = None,
    device: Optional[str] = None,
    p: Optional[int] = None,
    n: Optional[int] = None,
    any: Optional[int] = None,
    args_file: Optional[str] = None,
    **override,
):
    """
    Verify the input and output data shape and data type of network defined in the metadata.
    Will test with fake Tensor data according to the required data shape in `metadata`.

    Typical usage examples:

    .. code-block:: bash

        python -m monai.bundle verify_net_in_out network --meta_file <meta path> --config_file <config path>

    Args:
        net_id: ID name of the network component to verify, it must be `torch.nn.Module`.
        meta_file: filepath of the metadata file to get network args, if `None`, must be provided in `args_file`.
            if it is a list of file paths, the content of them will be merged.
        config_file: filepath of the config file to get network definition, if `None`, must be provided in `args_file`.
            if it is a list of file paths, the content of them will be merged.
        device: target device to run the network forward computation. if None, defaults to ``"cuda:0"``
            when CUDA is available (``is_available()``), otherwise ``"cpu"``.
        p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1.
        n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
        any: specified size to generate fake data shape if dim of expected shape is "*", default to 1.
        args_file: a JSON or YAML file to provide default values for `net_id`, `meta_file`, `config_file`,
            `device`, `p`, `n`, `any`, and override pairs. so that the command line inputs can be simplified.
        override: id-value pairs to override or add the corresponding config content.
            the metadata is read under the ``_meta_`` id, e.g.
            ``--_meta_#network_data_format#inputs#image#num_channels 3``.

    Raises:
        KeyError: if a ``KeyError`` is raised while resolving the network id or reading one of the required
            ``_meta_#network_data_format#...`` entries; the message names the id being looked up.
        ValueError: when the output channel number (``output.shape[1]``) or the output dtype
            doesn't match the metadata.

    """
    _args = _update_args(
        args=args_file,
        net_id=net_id,
        meta_file=meta_file,
        config_file=config_file,
        device=device,
        p=p,
        n=n,
        any=any,
        **override,
    )
    _log_input_summary(tag="verify_net_in_out", args=_args)
    config_file_, meta_file_, net_id_, device_, p_, n_, any_ = _pop_args(
        _args,
        "config_file",
        "meta_file",
        net_id="",
        device="cuda:0" if is_available() else "cpu",
        p=1,
        n=1,
        any=1,
    )

    parser = ConfigParser()
    parser.read_config(f=config_file_)
    parser.read_meta(f=meta_file_)

    # the rest key-values in the _args are to override config content
    for k, v in _args.items():
        parser[k] = v

    # `key` always holds the id currently being resolved, so the re-raised KeyError can
    # report exactly which entry was missing. Note that the network instantiation is also
    # inside this `try`, so a KeyError raised while building the network is reported the same way.
    try:
        key: str = net_id_  # mark the full id when KeyError
        net = parser.get_parsed_content(key).to(device_)

        key = "_meta_#network_data_format#inputs#image#num_channels"
        input_channels = parser[key]
        key = "_meta_#network_data_format#inputs#image#spatial_shape"
        input_spatial_shape = tuple(parser[key])
        key = "_meta_#network_data_format#inputs#image#dtype"
        input_dtype = get_equivalent_dtype(parser[key], torch.Tensor)

        key = "_meta_#network_data_format#outputs#pred#num_channels"
        output_channels = parser[key]
        key = "_meta_#network_data_format#outputs#pred#dtype"
        output_dtype = get_equivalent_dtype(parser[key], torch.Tensor)
    except KeyError as e:
        raise KeyError(f"Failed to verify due to missing expected key in the config: {key}.") from e

    net.eval()
    with torch.no_grad():
        spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p_, n=n_, any=any_)  # type: ignore
        # a single fake sample: (batch=1, channels, *spatial dims) drawn with torch.rand
        test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_)
        output = net(test_data)
        if output.shape[1] != output_channels:
            raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.")
        if output.dtype != output_dtype:
            raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.")

    logger.info("data shape of network is verified with no error.")
def test_value(self, argments, image, expected_data):
    """Check one parameterized case of ``HistogramNormalized``.

    The transform is built from ``argments`` and applied to ``image``; the ``"img"``
    entry of its result must be close to ``expected_data``. Its dtype, mapped to the
    equivalent numpy dtype, must equal ``argments["dtype"]`` (``np.float32`` if absent).
    The parameter name ``argments`` is kept as-is for compatibility with the test cases.
    """
    transform = HistogramNormalized(**argments)
    normalized = transform(image)["img"]
    assert_allclose(normalized, expected_data)

    actual_dtype = get_equivalent_dtype(normalized.dtype, data_type=np.ndarray)
    wanted_dtype = argments.get("dtype", np.float32)
    self.assertEqual(actual_dtype, wanted_dtype)