Пример #1
0
def export_policy_model(
    model_path: str,
    output_filepath: str,
    behavior_name: str,
    graph: tf.Graph,
    sess: tf.Session,
) -> None:
    """
    Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding.

    The frozen graph is always written to ``<model_path>/frozen_graph_def.pb``.
    The .nn export (if ``SerializationSettings.convert_to_barracuda``) and the
    .onnx export (if ONNX is importable and
    ``SerializationSettings.convert_to_onnx``) are written to
    ``output_filepath`` with the matching suffix appended.

    :param model_path: directory that receives the intermediate frozen graph
        (``frozen_graph_def.pb``); created if it does not exist
    :param output_filepath: file path to output the model (without file suffix);
        its parent directory is created if it does not exist
    :param behavior_name: behavior name of the trained model
    :param graph: Tensorflow Graph for the policy
    :param sess: Tensorflow session for the policy
    :raises RuntimeError: if ``_enforce_onnx_conversion()`` is true but the ONNX
        dependencies could not be imported
    """
    frozen_graph_def = _make_frozen_graph(behavior_name, graph, sess)

    # ``output_filepath`` is a path *prefix* for files (".nn"/".onnx" are
    # appended), so only its parent directory has to exist; creating a
    # directory named ``output_filepath`` itself would leave a stray empty
    # folder. ``model_path`` must exist because the frozen graph goes there.
    # exist_ok=True avoids the check-then-create race.
    for directory in (model_path, os.path.dirname(output_filepath)):
        if directory:  # dirname() returns "" for a bare filename (cwd)
            os.makedirs(directory, exist_ok=True)

    # Save frozen graph; it is the converter input and the artifact users are
    # asked to attach when ONNX conversion fails.
    frozen_graph_def_path = os.path.join(model_path, "frozen_graph_def.pb")
    with gfile.GFile(frozen_graph_def_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())

    # Convert to barracuda
    if SerializationSettings.convert_to_barracuda:
        tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn")
        logger.info(f"Exported {output_filepath}.nn")

    # Save to onnx too (if we were able to import it)
    if ONNX_EXPORT_ENABLED:
        if SerializationSettings.convert_to_onnx:
            try:
                onnx_graph = convert_frozen_to_onnx(behavior_name, frozen_graph_def)
                onnx_output_path = f"{output_filepath}.onnx"
                with open(onnx_output_path, "wb") as f:
                    f.write(onnx_graph.SerializeToString())
                # Logged after the file is written, so report it as done.
                logger.info(f"Exported {onnx_output_path}")
            except Exception:
                # Make conversion errors fatal depending on environment variables (only done during CI)
                if _enforce_onnx_conversion():
                    raise
                # Otherwise ONNX export is best-effort: log with traceback and continue.
                logger.exception(
                    "Exception trying to save ONNX graph. Please report this error on "
                    "https://github.com/Unity-Technologies/ml-agents/issues and "
                    "attach a copy of frozen_graph_def.pb"
                )

    else:
        if _enforce_onnx_conversion():
            raise RuntimeError(
                "ONNX conversion enforced, but couldn't import dependencies."
            )
Пример #2
0
def test_barracuda_converter():
    """
    Convert the bundled ``BasicLearning.pb`` to a barracuda ``.nn`` file and
    check that a non-trivial output file was produced.

    The output goes into a fresh temporary directory (public ``tempfile`` API
    instead of the private ``_get_default_tempdir``/``_get_candidate_names``),
    which is removed even if an assertion fails.
    """
    path_prefix = os.path.dirname(os.path.abspath(__file__))
    source_pb = os.path.join(path_prefix, "BasicLearning.pb")

    # A brand-new directory guarantees there are no left-over files, and the
    # context manager cleans up regardless of test outcome.
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpfile = os.path.join(tmpdir, "BasicLearning.nn")

        tf2bc.convert(source_pb, tmpfile)

        # test if file exists after conversion
        assert os.path.isfile(tmpfile)
        # currently converter produces small output file even if input file is empty
        # 100 bytes is high enough to prove that conversion was successful
        assert os.path.getsize(tmpfile) > 100