def test_symb_block_import_backward_compatible(block_type) -> None:
    """Check that a symbol block can be imported without its format JSON.

    Steps:

    1. Initialize and hybridize the block, then run it once on sample inputs.
    2. Export it under the name ``"gluonts-model"``.
    3. Delete the ``gluonts-model-in_out_format.json`` file that the export
       wrote.
    4. Require that ``import_symb_block`` still loads the model from the
       remaining files without raising ``FileNotFoundError``.

    Parameters
    ----------
    block_type
        Zero-argument callable that builds the block under test. Presumably
        this is a parametrized block class; confirm against the test's
        parametrization.
    """
    x1 = mx.nd.array([1, 2, 3])
    x2 = [mx.nd.array([1, 5, 5]), mx.nd.array([2, 3, 3])]

    my_block = block_type()
    my_block.collect_params().initialize()
    my_block.hybridize()
    # Run one forward pass on the hybridized block before exporting it.
    my_block(x1, x2)

    with tempfile.TemporaryDirectory(
        prefix="gluonts-estimator-temp-"
    ) as temp_dir:
        temp_path = Path(temp_dir)
        export_symb_block(my_block, temp_path, "gluonts-model")

        format_json_path = temp_path / "gluonts-model-in_out_format.json"
        assert format_json_path.exists()
        # Remove the format file to mimic an export that lacks it (the
        # backward-compatibility case this test is named for).
        format_json_path.unlink()

        # Guard only the import. A FileNotFoundError caught here then really
        # comes from importing without the format file, not from unlink().
        try:
            # NOTE(review): 3 presumably equals the flattened input count
            # (x1 plus the two arrays in x2); confirm against import_symb_block.
            import_symb_block(3, temp_path, "gluonts-model")
        except FileNotFoundError:
            pytest.fail(
                "Symbol block import fails when format json is not in path"
            )
def serialize_prediction_net(self, path: Path) -> None:
    """Export ``self.prediction_net`` into ``path`` via ``export_symb_block``.

    The exported artifacts use the model name ``"prediction_net"``.
    """
    net_name = "prediction_net"
    export_symb_block(self.prediction_net, path, net_name)