def export_policy_model(
    model_path: str,
    output_filepath: str,
    behavior_name: str,
    graph: tf.Graph,
    sess: tf.Session,
) -> None:
    """
    Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding.

    The frozen graph is always written to ``<model_path>/frozen_graph_def.pb``;
    it is the input for the Barracuda converter. Depending on
    ``SerializationSettings``, ``<output_filepath>.nn`` and/or
    ``<output_filepath>.onnx`` are then written.

    :param model_path: directory in which the intermediate frozen graph
        (frozen_graph_def.pb) is written
    :param output_filepath: file path to output the model (without file suffix)
    :param behavior_name: behavior name of the trained model
    :param graph: Tensorflow Graph for the policy
    :param sess: Tensorflow session for the policy
    :raises RuntimeError: if ONNX conversion is enforced (``_enforce_onnx_conversion()``)
        but the ONNX export dependencies could not be imported
    :raises Exception: an ONNX conversion/write error is re-raised when ONNX
        conversion is enforced; otherwise it is logged and swallowed
    """
    frozen_graph_def = _make_frozen_graph(behavior_name, graph, sess)

    # NOTE(review): this creates a *directory* named output_filepath, while the
    # exported files are written beside it as output_filepath + ".nn"/".onnx".
    # Its useful effect is creating the parent directory of those files; the
    # extra directory is kept for compatibility with the previous behavior.
    os.makedirs(output_filepath, exist_ok=True)

    # Save frozen graph
    frozen_graph_def_path = os.path.join(model_path, "frozen_graph_def.pb")
    with gfile.GFile(frozen_graph_def_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())

    # Convert to barracuda
    if SerializationSettings.convert_to_barracuda:
        tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn")
        logger.info(f"Exported {output_filepath}.nn")

    # Save to onnx too (if we were able to import it)
    if not ONNX_EXPORT_ENABLED:
        # The ONNX export dependencies could not be imported.
        if _enforce_onnx_conversion():
            raise RuntimeError(
                "ONNX conversion enforced, but couldn't import dependencies."
            )
    elif SerializationSettings.convert_to_onnx:
        onnx_output_path = f"{output_filepath}.onnx"
        try:
            onnx_graph = convert_frozen_to_onnx(behavior_name, frozen_graph_def)
            with open(onnx_output_path, "wb") as f:
                f.write(onnx_graph.SerializeToString())
        except Exception:
            # Make conversion errors fatal depending on environment variables (only done during CI)
            if _enforce_onnx_conversion():
                raise
            # Otherwise ONNX export is best-effort: log and continue.
            logger.exception(
                "Exception trying to save ONNX graph. Please report this error on "
                "https://github.com/Unity-Technologies/ml-agents/issues and "
                "attach a copy of frozen_graph_def.pb"
            )
        else:
            # Logged only once the file has actually been written.
            logger.info(f"Exported {onnx_output_path}")
def test_barracuda_converter():
    """Convert the bundled BasicLearning.pb with tf2bc and sanity-check the .nn output."""
    path_prefix = os.path.dirname(os.path.abspath(__file__))
    # A fresh temporary directory guarantees there are no left-over files from
    # earlier runs, and it is removed even if one of the assertions fails.
    # This also avoids the private tempfile._get_* helpers.
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpfile = os.path.join(tmpdir, "BasicLearning.nn")
        tf2bc.convert(os.path.join(path_prefix, "BasicLearning.pb"), tmpfile)

        # test if file exists after conversion
        assert os.path.isfile(tmpfile)
        # currently converter produces small output file even if input file is empty
        # 100 bytes is high enough to prove that conversion was successful
        assert os.path.getsize(tmpfile) > 100