Code example #1
File: reader.py  Project: m-novikov/tiktorch
def eval_model_zip(model_zip: ZipFile,
                   devices: Sequence[str],
                   cache_path: Optional[Path] = None) -> ModelAdapter:
    temp_path = Path(tempfile.mkdtemp(prefix="tiktorch_"))
    if cache_path is None:
        cache_path = temp_path / "cache"

    model_zip.extractall(temp_path)

    spec_file_str = guess_model_path(
        [str(file_name) for file_name in temp_path.glob("*")])
    if not spec_file_str:
        raise Exception(
            "Model config file not found; make sure a .model.yaml file is in the root of your model archive"
        )

    pybio_model = spec.load_and_resolve_spec(spec_file_str)
    ret = create_model_adapter(pybio_model=pybio_model, devices=devices)

    def _on_error(function, path, exc_info):
        logger.warning("Failed to delete temp directory %s", path)

    shutil.rmtree(temp_path, onerror=_on_error)

    return ret
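
For orientation, a minimal usage sketch of the function above; the import path for eval_model_zip is an assumption, since the listing only shows reader.py's function body:

from zipfile import ZipFile

from tiktorch.server.reader import eval_model_zip  # assumed module path

with ZipFile("model.zip", "r") as model_zip:
    # Build a ModelAdapter on CPU; device strings such as "cuda:0" are
    # passed through to the adapter unchanged.
    adapter = eval_model_zip(model_zip, devices=["cpu"])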
Code example #2
def test_load_specs_from_manifest(cache_path, category, spec_path):
    spec_path = MANIFEST_PATH.parent / spec_path
    assert spec_path.exists()

    loaded_spec = load_and_resolve_spec(str(spec_path))
    instance = utils.get_instance(loaded_spec)
    assert instance
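
The test above receives category and spec_path from parametrization that the listing omits. A hedged sketch of an equivalent pytest setup; the example manifest entry is purely illustrative:

import pytest

# Illustrative (category, spec_path) pairs; the real values are generated
# from the manifest next to MANIFEST_PATH, which the listing does not show.
EXAMPLE_ENTRIES = [
    ("model", "models/unet2d_nuclei_broad/UNet2DNucleiBroad.model.yaml"),
]

@pytest.mark.parametrize("category,spec_path", EXAMPLE_ENTRIES)
def test_manifest_entry_is_yaml(category, spec_path):
    # Stand-in check; the real test resolves spec_path against
    # MANIFEST_PATH.parent and instantiates the loaded spec.
    assert spec_path.endswith(".yaml")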
Code example #3
def test_UNet2dNucleiBroads_load_weights():
    spec_path = (
        Path(__file__).parent /
        "../../../specs/models/unet2d_nuclei_broad/UNet2DNucleiBroad.model.yaml"
    ).resolve()
    assert spec_path.exists(), spec_path
    model_spec = load_and_resolve_spec(spec_path)
    assert isinstance(model_spec.weights["pytorch_state_dict"].source, Path)
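
The assertion documents that load_and_resolve_spec materializes file references: after resolution, the pytorch_state_dict weights entry points at a local Path. A small sketch of the same inspection, assuming load_and_resolve_spec is importable from pybio.spec (other listings here call it as spec.load_and_resolve_spec):

from pathlib import Path

from pybio.spec import load_and_resolve_spec  # assumed import path

spec_path = Path("specs/models/unet2d_nuclei_broad/UNet2DNucleiBroad.model.yaml").resolve()
model_spec = load_and_resolve_spec(spec_path)
weights = model_spec.weights["pytorch_state_dict"]
print(type(weights.source))  # expected: pathlib.Path after resolution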
Code example #4
def _load_from_zip(model_zip: ZipFile):
    temp_path = Path(tempfile.mkdtemp(prefix="tiktorch_"))
    cache_path = temp_path / "cache"

    model_zip.extractall(temp_path)

    spec_file_str = guess_model_path(
        [str(file_name) for file_name in temp_path.glob("*")])
    if not spec_file_str:
        raise Exception(
            "Model config file not found; make sure a .model.yaml file is in the root of your model archive"
        )
    return spec.load_and_resolve_spec(spec_file_str), cache_path
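
Note that _load_from_zip extracts into a fresh temp directory and returns without cleaning it up, so deletion is the caller's job, as main() below shows. A hedged sketch of that contract, assuming the function above is in scope:

import shutil
from zipfile import ZipFile

with ZipFile("model.zip", "r") as model_zip:
    pybio_model, cache_path = _load_from_zip(model_zip)

try:
    pass  # run inference with pybio_model here
finally:
    # cache_path is temp_path / "cache", so its parent is the extraction
    # directory _load_from_zip created; removing it cleans up everything.
    shutil.rmtree(cache_path.parent, ignore_errors=True)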
Code example #5
def main():
    args = parser.parse_args()
    # try opening model from model.zip
    try:
        with ZipFile(args.model, "r") as model_zip:
            pybio_model, cache_path = _load_from_zip(model_zip)
    # otherwise open from model.yaml
    except BadZipFile:
        spec_path = os.path.abspath(args.model)
        pybio_model = spec.load_and_resolve_spec(spec_path)
        cache_path = None

    model = create_prediction_pipeline(pybio_model=pybio_model,
                                       devices=["cpu"],
                                       weight_format=args.weight_format,
                                       preserve_batch_dim=True)

    input_args = [
        load_data(inp, inp_spec)
        for inp, inp_spec in zip(pybio_model.test_inputs, pybio_model.inputs)
    ]
    expected_outputs = [
        load_data(out, out_spec)
        for out, out_spec in zip(pybio_model.test_outputs, pybio_model.outputs)
    ]

    results = [model.forward(*input_args)]

    for res, exp in zip(results, expected_outputs):
        assert_array_almost_equal(exp, res, args.decimals)

    if cache_path is not None:

        def _on_error(function, path, exc_info):
            # warnings.warn does not do %-style formatting (its second
            # argument is a warning category), so interpolate first
            warnings.warn("Failed to delete temp directory %s" % path)

        shutil.rmtree(cache_path, onerror=_on_error)

    print("All results match the expected output")
    return 0
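
main() reads args.model, args.weight_format, and args.decimals from a module-level parser that the listing omits. A hypothetical definition covering exactly those three attributes; flag names and defaults are assumptions:

import argparse

# Hypothetical parser; only the attribute names main() uses are grounded
# in the listing above.
parser = argparse.ArgumentParser(
    description="Run a model on its test inputs and compare to its test outputs")
parser.add_argument("model", help="path to a model zip archive or a model .yaml spec")
parser.add_argument("--weight-format", dest="weight_format", default=None,
                    help="weight format to load, e.g. pytorch_state_dict")
parser.add_argument("--decimals", type=int, default=4,
                    help="precision passed to assert_array_almost_equal")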