Ejemplo n.º 1
0
def copy_parameters(
    net_source: mx.gluon.Block,
    net_dest: mx.gluon.Block,
    ignore_extra: bool = False,
    allow_missing: bool = False,
) -> None:
    """
    Transfer the parameters of ``net_source`` into ``net_dest``.

    The source parameters are written to a file inside a temporary
    directory and then loaded back into the destination network on the
    current MXNet context.

    Parameters
    ----------
    net_source
        Network whose parameters are saved.
    net_dest
        Network that receives the saved parameters.
    ignore_extra
        Forwarded to ``net_dest.load_parameters``; whether parameters
        saved from the source but absent in the destination are ignored.
    allow_missing
        Forwarded to ``net_dest.load_parameters``; whether the destination
        may have parameters that were not saved from the source.
    """
    with tempfile.TemporaryDirectory(
        prefix="gluonts-estimator-temp-"
    ) as scratch_dir:
        # A single file in the scratch directory is the hand-off point
        # between the two networks.
        params_file = str(Path(scratch_dir, "tmp_model"))

        net_source.save_parameters(params_file)
        net_dest.load_parameters(
            params_file,
            ctx=mx.current_context(),
            allow_missing=allow_missing,
            ignore_extra=ignore_extra,
        )
Ejemplo n.º 2
0
def copy_parameters(net_source: mx.gluon.Block,
                    net_dest: mx.gluon.Block,
                    ignore_extra: bool = False,
                    allow_missing: bool = False) -> None:
    """
    Copies parameters from one network to another.

    The parameters of ``net_source`` are saved to a file in a temporary
    directory and then loaded into ``net_dest`` on the current MXNet
    context.

    Parameters
    ----------
    net_source
        Input network.
    net_dest
        Output network.
    ignore_extra
        Whether to silently ignore parameters from the source that are not
        present in the target. Defaults to ``False``, which keeps the
        previous hard-coded behaviour.
    allow_missing
        Whether to allow additional parameters in the target not present
        in the source. Defaults to ``False``, which keeps the previous
        hard-coded behaviour.
    """
    with tempfile.TemporaryDirectory(
            prefix='gluonts-estimator-temp-') as model_dir:
        model_dir_path = str(Path(model_dir) / 'tmp_model')
        net_source.save_parameters(model_dir_path)
        # Both flags are passed through so callers can relax the matching
        # between source and target parameter sets when they need to.
        net_dest.load_parameters(
            model_dir_path,
            ctx=mx.current_context(),
            allow_missing=allow_missing,
            ignore_extra=ignore_extra,
        )