Пример #1
0
def test_symb_block_import_backward_compatible(block_type) -> None:
    """Check that ``import_symb_block`` works without the format JSON file.

    Builds, initializes, hybridizes and runs ``block_type()``, then exports
    it with ``export_symb_block``. It deletes the
    ``gluonts-model-in_out_format.json`` file written next to the export and
    asserts that importing the model does not raise ``FileNotFoundError``.
    """
    x1 = mx.nd.array([1, 2, 3])
    x2 = [mx.nd.array([1, 5, 5]), mx.nd.array([2, 3, 3])]

    my_block = block_type()
    my_block.collect_params().initialize()
    my_block.hybridize()
    # MXNet requires at least one forward call on a hybridized block
    # before it can be exported.
    my_block(x1, x2)

    with tempfile.TemporaryDirectory(
        prefix="gluonts-estimator-temp-"
    ) as temp_dir:
        temp_path = Path(temp_dir)

        export_symb_block(my_block, temp_path, "gluonts-model")

        format_json_path = temp_path / "gluonts-model-in_out_format.json"
        assert format_json_path.exists()

        # Remove the format file to reproduce the backward-compatibility
        # case. This is done outside the ``try`` so that a failure here is
        # not misreported as an import failure.
        format_json_path.unlink()

        try:
            # 3 = x1 plus the two arrays inside x2 (total input count).
            import_symb_block(3, temp_path, "gluonts-model")
        except FileNotFoundError:
            pytest.fail(
                "Symbol block import fails when format json is not in path"
            )
Пример #2
0
    def deserialize(
        cls, path: Path, ctx: Optional[mx.Context] = None
    ) -> "SymbolBlockPredictor":
        """Rebuild a ``SymbolBlockPredictor`` from the files stored in ``path``.

        Reads ``parameters.json`` (constructor keyword arguments) and
        ``input_transform.json`` (the transformation chain), then imports
        the exported ``prediction_net`` symbol block. Everything runs inside
        ``mx.Context(ctx)``.

        Parameters
        ----------
        path
            Directory containing the serialized predictor.
        ctx
            MXNet context to use. When ``None``, ``get_mxnet_context()`` is
            used. The chosen context always replaces any ``ctx`` entry stored
            in ``parameters.json``.

        Returns
        -------
        SymbolBlockPredictor
            A predictor built from the stored parameters, transform and
            network.
        """
        if ctx is None:
            ctx = get_mxnet_context()

        def _read_json(file_name: str):
            # Read one JSON file from the predictor directory and decode it
            # with load_json.
            with (path / file_name).open("r") as handle:
                return load_json(handle.read())

        with mx.Context(ctx):
            # Constructor keyword arguments; the context is forced to `ctx`.
            init_kwargs = _read_json("parameters.json")
            init_kwargs["ctx"] = ctx

            # Transformation chain applied to the inputs.
            input_chain = _read_json("input_transform.json")

            # The stored input names give the network's input count.
            network = import_symb_block(
                len(init_kwargs["input_names"]), path, "prediction_net"
            )

            return SymbolBlockPredictor(
                input_transform=input_chain,
                prediction_net=network,
                **init_kwargs,
            )