コード例 #1
0
def export_symb_block(hb: mx.gluon.HybridBlock,
                      model_dir: Path,
                      model_name: str,
                      epoch: int = 0) -> None:
    """
    Export a hybridized Gluon `HybridBlock` through ``HybridBlock.export``.

    All files are produced by ``hb.export`` using the prefix
    ``model_dir / model_name``.

    Parameters
    ----------
    hb
        The hybridized block to export.
    model_dir
        Directory that receives the exported files.
    model_name
        File-name prefix used inside `model_dir`.
    epoch
        Epoch number forwarded to ``hb.export``; together with `model_name`
        it identifies the saved model parameters.
    """
    export_prefix = model_dir / model_name
    hb.export(path=str(export_prefix), epoch=epoch)
コード例 #2
0
def export_repr_block(
    rb: mx.gluon.HybridBlock, model_dir: Path, model_name: str, epoch: int = 0
) -> None:
    """
    Serialize a representable Gluon block as a JSON description plus weights.

    Two files are written into `model_dir`, in this order:

    * ``{model_name}-network.json`` -- the text returned by ``dump_json(rb)``
      followed by a newline;
    * ``{model_name}-{epoch:04}.params`` -- the block's parameters, written by
      ``rb.save_parameters``.

    Parameters
    ----------
    rb
        The representable block to serialize.
    model_dir
        Directory that receives both files.
    model_name
        File-name prefix shared by both files.
    epoch
        Epoch number, zero-padded to four digits in the parameter file name;
        together with `model_name` it identifies the model parameters.
    """
    network_path = model_dir / f"{model_name}-network.json"
    with network_path.open("w") as network_file:
        # print() appends the trailing newline after the JSON text.
        print(dump_json(rb), file=network_file)

    params_path = model_dir / f"{model_name}-{epoch:04}.params"
    rb.save_parameters(str(params_path))
コード例 #3
0
def export_symb_block(hb: mx.gluon.HybridBlock,
                      model_dir: Path,
                      model_name: str,
                      epoch: int = 0) -> None:
    """
    Export a hybridized Gluon `HybridBlock` together with its I/O formats.

    First ``hb.export`` is called with the prefix ``model_dir / model_name``.
    Then ``{model_name}-in_out_format.json`` is written to `model_dir`. It
    holds a JSON object (produced by ``dump_json``) with the keys
    ``in_format`` and ``out_format``, taken from the block's private
    ``_in_format`` and ``_out_format`` attributes.

    Parameters
    ----------
    hb
        The hybridized block to export.
    model_dir
        Directory that receives the exported files.
    model_name
        File-name prefix used inside `model_dir`.
    epoch
        Epoch number forwarded to ``hb.export``; together with `model_name`
        it identifies the saved model parameters.
    """
    hb.export(path=str(model_dir / model_name), epoch=epoch)

    format_path = model_dir / f"{model_name}-in_out_format.json"
    with format_path.open("w") as format_file:
        # NOTE(review): relies on private MXNet attributes; presumably they
        # are populated once the hybridized block has been run -- confirm.
        io_formats = {
            "in_format": hb._in_format,
            "out_format": hb._out_format,
        }
        print(dump_json(io_formats), file=format_file)
コード例 #4
0
def equals_representable_block(this: mx.gluon.HybridBlock,
                               that: mx.gluon.HybridBlock) -> bool:
    """
    Check two :class:`~mxnet.gluon.HybridBlock` objects built with
    :func:`validated` initializers for structural equality.

    ``this`` and ``that`` are *structurally equal* when they satisfy every
    condition checked by :func:`equals`. In addition, the parameter
    dictionaries returned by :func:`~mxnet.gluon.block.Block.collect_params`
    must compare structurally equal as well.

    This is the specialization of :func:`equals` used when the first argument
    is a :class:`~mxnet.gluon.HybridBlock` instance.

    Parameters
    ----------
    this, that
        The blocks to compare.

    Returns
    -------
    bool
        ``True`` if ``this`` and ``that`` are structurally equal, otherwise
        ``False``.

    See Also
    --------
    equals
        Dispatching function.
    equals_parameter_dict
        Specialization of :func:`equals` for Gluon
        :class:`~mxnet.gluon.ParameterDict` input arguments.
    """
    # The parameter dictionaries are only collected and compared once the
    # generic comparison has already succeeded.
    if equals_default_impl(this, that):
        this_params = this.collect_params()
        that_params = that.collect_params()
        return bool(equals_parameter_dict(this_params, that_params))
    return False
コード例 #5
0
def objective_function(model: mx.gluon.HybridBlock,
                       training_data_iterator: mx.io.NDArrayIter,
                       loss: mx.gluon.loss.Loss,
                       gamma=AdaNetConfig.GAMMA.value) -> nd.NDArray:
    """
    Compute the (AdaNet) objective of `model` over one pass of the training data.

    The objective is the mean per-sample loss over all real (non-padding)
    samples yielded by `training_data_iterator`, plus the mean of the
    candidate complexities returned by ``model.get_candidate_complexity()``
    scaled by `gamma`.

    :param model: Union[SuperCandidateHull, ModelTemplate]. It is called on
        each batch's data and must provide ``get_candidate_complexity()``.
    :param training_data_iterator: Training-data iterator. It is reset before
        the pass and is left exhausted afterwards.
    :param loss: Gluon loss applied to ``(prediction, label)`` for every batch.
    :param gamma: Weight applied to the candidate complexities.
    :return: The objective, taken as ``objective[0][0]`` of the summed terms.
    :raises ValueError: If the iterator yields no batches.
    """
    training_data_iterator.reset()
    err_list = []
    for batch in training_data_iterator:
        pred = model(batch.data[0])[0][0]
        label = batch.label[0]
        error = loss(pred, label)
        # When the data size is not a multiple of the batch size, NDArrayIter
        # (last_batch_handle='pad' by default) fills the final batch with
        # samples from the start of the data, and `batch.pad` counts them.
        # Drop their losses so those samples are not counted twice in the
        # mean. Assumes `error` is per-sample along axis 0 (Gluon's default
        # batch_axis=0) -- confirm if the model's output layout changes.
        pad = batch.pad or 0
        if pad:
            error = error[:error.shape[0] - pad]
        err_list.append(error)

    if not err_list:
        # Fail with a clear message instead of an opaque error inside
        # concatenate() on an empty list.
        raise ValueError("training_data_iterator yielded no batches")

    err = concatenate(err_list)
    c_complexities = model.get_candidate_complexity() * gamma
    objective = err.mean() + c_complexities.mean()

    return objective[0][0]
コード例 #6
0
ファイル: utils.py プロジェクト: slyforce/MusicStyleTransfer
def load_model_parameters(model: mx.gluon.HybridBlock, path: str,
                          context: mx.Context):
    """
    Load the parameters stored at `path` into `model`.

    Thin wrapper around ``model.load_parameters``.

    Parameters
    ----------
    model
        The block whose parameters are loaded.
    path
        Path of the parameter file.
    context
        Context passed to ``load_parameters`` as ``ctx``.
    """
    target_ctx = context
    model.load_parameters(path, ctx=target_ctx)
コード例 #7
0
ファイル: utils.py プロジェクト: slyforce/MusicStyleTransfer
def save_model(model: mx.gluon.HybridBlock, output_path: str):
    """
    Save the parameters of `model` to `output_path`.

    Thin wrapper around ``model.save_parameters``.

    Parameters
    ----------
    model
        The block whose parameters are written.
    output_path
        Destination file path.
    """
    destination = output_path
    model.save_parameters(destination)