def convert_weights_to_pytorch_script(
    model_spec: Union[str, Path, spec.raw_nodes.Model],
    output_path: Union[str, Path],
    use_tracing: bool = True,
):
    """Convert model weights from format 'pytorch_state_dict' to 'torchscript'.

    Arguments:
        model_spec: location of the model.yaml file with bioimage.io spec, or an
            already loaded model spec
        output_path: where to save the torchscript weights
        use_tracing: export with ``torch.jit.trace`` (True) or ``torch.jit.script`` (False)

    Returns:
        The result of ``_check_predictions`` comparing the original model with the
        scripted one on the first test input.
    """
    if isinstance(model_spec, (str, Path)):
        # Pass the directory containing the yaml as root path so that relative
        # entries in the spec (test inputs, weights, ...) are resolved against it.
        root = Path(model_spec).parent
        model_spec = spec.load_model(Path(model_spec), root_path=root)

    with torch.no_grad():
        # load input data (cast to float32) from the first test input
        input_data = np.load(model_spec.test_inputs[0]).astype('float32')
        input_data = torch.from_numpy(input_data)

        # instantiate model and get reference output
        model = load_model(model_spec)

        # make scripted model
        if use_tracing:
            scripted_model = torch.jit.trace(model, input_data)
        else:
            scripted_model = torch.jit.script(model)

        # check the scripted model
        ret = _check_predictions(model, scripted_model, model_spec, input_data)

    # save the torchscript model; str() because older torch versions of
    # ScriptModule.save may not accept pathlib.Path
    scripted_model.save(str(output_path))
    return ret
def eval_model_zip(model_zip: ZipFile):
    """Extract a model archive into a temporary directory and instantiate its network.

    Arguments:
        model_zip: open model archive

    Returns:
        Whatever ``get_nn_instance`` returns for the model loaded from the spec
        file found in the archive root.

    Raises:
        Exception: if no model config file is found in the archive root.
    """
    with TemporaryDirectory() as tempdir:
        temp_path = Path(tempdir)
        model_zip.extractall(temp_path)
        # only top-level entries of the extracted archive are considered
        spec_file_str = guess_model_path(
            [str(file_name) for file_name in temp_path.glob("*")])
        if not spec_file_str:
            raise Exception(
                "Model config file not found, make sure that .model.yaml file in the root of your model archive"
            )
        bioimageio_model = load_model(spec_file_str)
        # NOTE(review): the temporary directory is deleted when this function
        # returns; this assumes get_nn_instance does not keep lazy references to
        # the extracted files — confirm.
        return get_nn_instance(bioimageio_model)
def _load_from_zip(model_zip: ZipFile):
    """Extract a model archive into a fresh temporary directory and load its spec.

    Arguments:
        model_zip: open model archive

    Returns:
        Tuple ``(model_spec, cache_path)``. ``model_spec`` is the result of
        ``spec.load_model`` on the spec file found in the archive root.
        ``cache_path`` is ``<temp dir>/cache``; that directory is not created
        here. On success the temporary directory is left in place for the caller.

    Raises:
        Exception: if no model config file is found in the archive root.
    """
    import shutil

    temp_path = Path(tempfile.mkdtemp(prefix="tiktorch_"))
    cache_path = temp_path / "cache"
    try:
        model_zip.extractall(temp_path)
        # only top-level entries of the extracted archive are considered
        spec_file_str = guess_model_path(
            [str(file_name) for file_name in temp_path.glob("*")])
        if not spec_file_str:
            raise Exception(
                "Model config file not found, make sure that .model.yaml file in the root of your model archive"
            )
        return spec.load_model(spec_file_str), cache_path
    except BaseException:
        # don't leak the extracted archive when loading fails
        shutil.rmtree(temp_path, ignore_errors=True)
        raise
def convert_weights_to_onnx(
    model_spec: Union[str, Path, spec.raw_nodes.Model],
    output_path: Union[str, Path],
    opset_version: Union[int, None] = 12,
    use_tracing: bool = True,
    verbose: bool = True,
):
    """Convert model weights from format 'pytorch_state_dict' to 'onnx'.

    Arguments:
        model_spec: location of the model.yaml file with bioimage.io spec, or an
            already loaded model spec
        output_path: where to save the onnx weights
        opset_version: onnx opset version
        use_tracing: whether to use tracing or scripting to export the onnx format
            (only tracing is implemented)
        verbose: be verbose during the onnx export

    Returns:
        0 if the onnx model reproduces the pytorch output on the first test input
        (to 4 decimals), 1 otherwise (a warning is emitted in that case).

    Raises:
        RuntimeError: if onnxruntime is not available.
        NotImplementedError: if ``use_tracing`` is False.
    """
    if rt is None:
        raise RuntimeError("Could not find onnxruntime.")

    if isinstance(model_spec, (str, Path)):
        # resolve relative entries of the spec against the yaml's directory
        root = os.path.split(model_spec)[0]
        model_spec = spec.load_model(Path(model_spec), root_path=root)

    # torch.onnx.export and onnxruntime.InferenceSession may not accept
    # pathlib.Path in all versions, so hand them a plain string.
    output_path = str(output_path)

    with torch.no_grad():
        # load input data (cast to float32) from the first test input
        input_data = np.load(model_spec.test_inputs[0]).astype('float32')
        input_tensor = torch.from_numpy(input_data)

        # instantiate and generate the expected output
        model = load_model(model_spec)
        expected_output = model(input_tensor).numpy()

        if use_tracing:
            torch.onnx.export(model, input_tensor, output_path, verbose=verbose, opset_version=opset_version)
        else:
            raise NotImplementedError("Exporting to onnx via scripting is not supported, use tracing.")

        # check the onnx model against the pytorch reference output
        sess = rt.InferenceSession(output_path)
        input_name = sess.get_inputs()[0].name
        output = sess.run(None, {input_name: input_data})[0]

    try:
        assert_array_almost_equal(expected_output, output, decimal=4)
        return 0
    except AssertionError as e:
        msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}"
        warnings.warn(msg)
        return 1
def main():
    """Run a model's test inputs through its prediction pipeline and compare the outputs.

    Reads ``args.model`` (a model.zip archive or a model.yaml path),
    ``args.weight_format`` and ``args.decimals`` from the module-level ``parser``.

    Returns:
        0 if every output matches its expected test output.

    Raises:
        AssertionError: if an output differs from the expected test output.
    """
    args = parser.parse_args()

    # try opening model from model.zip
    try:
        with ZipFile(args.model, "r") as model_zip:
            bioimageio_model, cache_path = _load_from_zip(model_zip)
    # otherwise open from model.yaml
    except BadZipFile:
        spec_path = os.path.abspath(args.model)
        bioimageio_model = spec.load_model(spec_path)
        cache_path = None

    try:
        model = create_prediction_pipeline(
            bioimageio_model=bioimageio_model,
            devices=["cpu"],
            weight_format=args.weight_format,
            preserve_batch_dim=True,
        )

        input_args = [
            load_data(inp, inp_spec)
            for inp, inp_spec in zip(bioimageio_model.test_inputs, bioimageio_model.inputs)
        ]
        expected_outputs = [
            load_data(out, out_spec)
            for out, out_spec in zip(bioimageio_model.test_outputs, bioimageio_model.outputs)
        ]

        results = [model.forward(*input_args)]

        for res, exp in zip(results, expected_outputs):
            assert_array_almost_equal(exp, res, args.decimals)
    finally:
        # clean up even when a comparison fails
        if cache_path is not None and os.path.exists(cache_path):

            def _on_error(function, path, exc_info):
                # the second positional argument of warnings.warn is the warning
                # category, so the path must be formatted into the message
                warnings.warn(f"Failed to delete temp directory {path}")

            shutil.rmtree(cache_path, onerror=_on_error)

    print("All results match the expected output")
    return 0