Ejemplo n.º 1
0
    def test_export(self, key_in_ckpt):
        """Run the ``monai.bundle ckpt_export`` CLI and check that a TorchScript file is produced.

        The ``network_def`` entry of the testing ``inference.json`` is instantiated and its
        state saved to a temporary checkpoint. When ``key_in_ckpt`` is the empty string the
        bare network is saved; otherwise it is wrapped as ``{key_in_ckpt: net}``. The export
        command is then executed under ``coverage`` and the test asserts that the ``.ts``
        output file exists.

        Args:
            key_in_ckpt: key under which the network is stored in the checkpoint; ``""``
                means "not nested". It is also forwarded to the CLI as ``--key_in_ckpt``.
        """
        data_dir = os.path.join(os.path.dirname(__file__), "testing_data")
        meta_file = os.path.join(data_dir, "metadata.json")
        config_file = os.path.join(data_dir, "inference.json")

        with tempfile.TemporaryDirectory() as tempdir:
            def_args_file = os.path.join(tempdir, "def_args.json")
            ckpt_file = os.path.join(tempdir, "model.pt")
            ts_file = os.path.join(tempdir, "model.ts")

            parser = ConfigParser()
            # Placeholder args file; as its own text says, the value is expected to be
            # replaced by the `--meta_file` CLI argument passed below.
            parser.export_config_file(
                config={"meta_file": "will be replaced by `meta_file` arg"},
                filepath=def_args_file,
            )
            parser.read_config(config_file)
            net = parser.get_parsed_content("network_def")

            state_src = {key_in_ckpt: net} if key_in_ckpt != "" else net
            save_state(src=state_src, path=ckpt_file)

            subprocess.check_call([
                "coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def",
                "--filepath", ts_file,
                "--meta_file", meta_file,
                "--config_file", config_file,
                "--ckpt_file", ckpt_file,
                "--key_in_ckpt", key_in_ckpt,
                "--args_file", def_args_file,
            ])
            self.assertTrue(os.path.exists(ts_file))
Ejemplo n.º 2
0
 def test_file(self, src, expected_keys, create_dir=True, atomic=True, func=None, kwargs=None):
     """Save ``src`` with ``save_state`` into a temp file and check the stored keys.

     The checkpoint is written to ``<tempdir>/test_ckpt.pt`` with the given
     ``create_dir``/``atomic``/``func`` options plus any extra ``kwargs``. It is then
     reloaded with ``torch.load`` and converted to a dict. Every key found in the loaded
     checkpoint must appear in ``expected_keys``. Note that only this direction is checked:
     expected keys that are missing from the checkpoint do not fail the test.
     """
     extra_kwargs = {} if kwargs is None else kwargs
     with tempfile.TemporaryDirectory() as tempdir:
         ckpt_path = os.path.join(tempdir, "test_ckpt.pt")
         save_state(
             src=src,
             path=ckpt_path,
             create_dir=create_dir,
             atomic=atomic,
             func=func,
             **extra_kwargs,
         )
         saved = dict(torch.load(ckpt_path))
         for key in saved:
             self.assertIn(key, expected_keys)
Ejemplo n.º 3
0
def init_bundle(
    bundle_dir: PathLike,
    ckpt_file: Optional[PathLike] = None,
    network: Optional[torch.nn.Module] = None,
    metadata_str: Union[Dict, str] = DEFAULT_METADATA,
    inference_str: Union[Dict, str] = DEFAULT_INFERENCE,
):
    """
    Initialise a new bundle directory with some default configuration files and optionally network weights.

    Typical usage example:

    .. code-block:: bash

        python -m monai.bundle init_bundle /path/to/bundle_dir network_ckpt.pt

    The created layout is::

        bundle_dir/
            configs/metadata.json
            configs/inference.json
            models/model.pt    (only when `ckpt_file` or `network` is given)
            docs/README.md
            docs/license.txt

    Args:
        bundle_dir: directory name to create, must not exist but its parent directory must exist
        ckpt_file: optional checkpoint file to copy into bundle as `models/model.pt`
        network: if given instead of ckpt_file this network's weights will be stored in bundle
            (via `save_state`); ignored when `ckpt_file` is also given
        metadata_str: content of `configs/metadata.json`; a str is written verbatim, a dict is
            serialised with `json.dumps(..., indent=4)`
        inference_str: content of `configs/inference.json`, handled the same way as `metadata_str`

    Raises:
        ValueError: if `bundle_dir` already exists, or its parent directory does not exist.
    """

    bundle_dir = Path(bundle_dir).absolute()

    if bundle_dir.exists():
        raise ValueError(f"Specified bundle directory '{bundle_dir}' already exists")

    if not bundle_dir.parent.is_dir():
        raise ValueError(f"Parent directory of specified bundle directory '{bundle_dir}' does not exist")

    configs_dir = bundle_dir / "configs"
    models_dir = bundle_dir / "models"
    docs_dir = bundle_dir / "docs"

    bundle_dir.mkdir()
    configs_dir.mkdir()
    models_dir.mkdir()
    docs_dir.mkdir()

    if isinstance(metadata_str, dict):
        metadata_str = json.dumps(metadata_str, indent=4)

    if isinstance(inference_str, dict):
        inference_str = json.dumps(inference_str, indent=4)

    with open(str(configs_dir / "metadata.json"), "w") as o:
        o.write(metadata_str)

    with open(str(configs_dir / "inference.json"), "w") as o:
        o.write(inference_str)

    # Raw string on purpose: in a normal literal every trailing `\` + newline is consumed as a
    # line continuation, which would collapse the multi-line shell command below into a single
    # line in the generated README.md.
    readme = r"""
        # Your Model Name

        Describe your model here and how to run it, for example using `inference.json`:

        ```
        python -m monai.bundle run evaluating \
            --meta_file /path/to/bundle/configs/metadata.json \
            --config_file /path/to/bundle/configs/inference.json \
            --dataset_dir ./input \
            --bundle_root /path/to/bundle
        ```
        """

    with open(str(docs_dir / "README.md"), "w") as o:
        o.write(dedent(readme))

    with open(str(docs_dir / "license.txt"), "w") as o:
        o.write("Select a license and place its terms here\n")

    # A checkpoint file takes precedence over an in-memory network.
    if ckpt_file is not None:
        copyfile(str(ckpt_file), str(models_dir / "model.pt"))
    elif network is not None:
        save_state(network, str(models_dir / "model.pt"))