Code example #1 (score: 0)
File: converter.py — Project: dallarosa/tfjs-to-tf
def convert(arguments):
    """
    Convert a TensorflowJS-model to a TensorFlow-model.

    Args:
        arguments: List of command-line arguments

    Raises:
        ValueError: If input_path or output_path is missing, or if the
            requested output format is not supported.
    """
    args = get_arg_parser().parse_args(arguments)
    if args.show_version:
        # --version short-circuits the conversion entirely
        print("\ntfjs_graph_converter {}\n".format(version.VERSION))
        print("Dependency versions:")
        print("    tensorflow {}".format(tf.version.VERSION))
        print("    tensorflowjs {}".format(tfjs.__version__))
        return

    def info(message, end=None):
        # Progress-output helper; suppressed when --silence is given
        if not args.silence:
            print(message, end=end)

    if not args.input_path:
        raise ValueError(
            "Missing input_path argument. For usage, use the --help flag.")
    if not args.output_path:
        raise ValueError(
            "Missing output_path argument. For usage, use the --help flag.")

    info("TensorFlow.js Graph Model Converter\n")
    info("Graph model:    {}".format(args.input_path))
    info("Output:         {}".format(args.output_path))
    info("Target format:  {}".format(args.output_format))
    info("\nConverting....", end=" ")

    start_time = time.perf_counter()

    if args.output_format == common.CLI_FROZEN_MODEL:
        api.graph_model_to_frozen_graph(args.input_path, args.output_path)
    elif args.output_format == common.CLI_SAVED_MODEL:
        api.graph_model_to_saved_model(args.input_path, args.output_path,
                                       args.saved_model_tags)
    else:
        raise ValueError("Unsupported output format: {}".format(
            args.output_format))

    end_time = time.perf_counter()
    info("Done.")
    info("Conversion took {0:.3f}s".format(end_time - start_time))
Code example #2 (score: 0)
 def test_graph_model_to_frozen_graph(self):
     """graph_model_to_frozen_graph should save valid frozen graph model"""
     # Resolve both paths BEFORE entering the try-block so that
     # `output_name` is always bound when the `finally` clause runs.
     # Previously, a failure inside get_path_to would leave `output_name`
     # undefined and the cleanup raised NameError, masking the real error.
     input_name = testutils.get_path_to(
         testutils.SIMPLE_MODEL_PATH_NAME)
     output_name = os.path.join(tempfile.gettempdir(), 'frozen.pb')
     try:
         api.graph_model_to_frozen_graph(input_name, output_name)
         # make sure the output file exists and isn't empty
         self.assertTrue(os.path.exists(output_name))
         self.assertGreater(os.stat(output_name).st_size, 256)
         # file must be a valid protobuf message
         with open(output_name, 'rb') as pb_file:
             graph_def = testutils.GraphDef()
             graph_def.ParseFromString(pb_file.read())
     finally:
         # always remove the temporary artefact, even on test failure
         if os.path.exists(output_name):
             os.remove(output_name)
Code example #3 (score: 0)
def convert(arguments):
    """
    Convert a TensorflowJS-model to a TensorFlow-model.

    Args:
        arguments: List of command-line arguments

    Raises:
        ValueError: If input_path or output_path is missing, if a
            saved-model option is used without --outputs, or if the
            output format is unsupported.
    """
    args = get_arg_parser().parse_args(arguments)
    if args.show_version:
        # --version short-circuits the conversion entirely
        print(f"\ntfjs_graph_converter {version.VERSION}\n")
        print("Dependency versions:")
        print(f"    tensorflow {tf.version.VERSION}")
        print(f"    tensorflowjs {tfjs.__version__}")
        return

    def info(message, end=None):
        # Progress-output helper; suppressed when --silence is given
        if not args.silence:
            print(message, end=end, flush=True)

    if not args.input_path:
        raise ValueError(
            "Missing input_path argument. For usage, use the --help flag.")
    if not args.output_path:
        raise ValueError(
            "Missing output_path argument. For usage, use the --help flag.")
    if args.output_format == common.CLI_SAVED_MODEL:
        # These options only make sense when explicit outputs are given;
        # validate them uniformly instead of repeating the same check
        dependent_options = (
            (args.signature_key, common.CLI_SIGNATURE_KEY),
            (args.method_name, common.CLI_METHOD_NAME),
            (args.rename, common.CLI_RENAME),
        )
        for value, option_name in dependent_options:
            if value is not None and args.outputs is None:
                raise ValueError(f'--{option_name} requires '
                                 f'--{common.CLI_OUTPUTS} to be specified')

    info("TensorFlow.js Graph Model Converter\n")
    info(f"Graph model:    {args.input_path}")
    info(f"Output:         {args.output_path}")
    info(f"Target format:  {args.output_format}")
    info("\nConverting....", end=" ")

    start_time = time.perf_counter()

    if args.output_format == common.CLI_FROZEN_MODEL:
        api.graph_model_to_frozen_graph(args.input_path, args.output_path,
                                        args.compat_mode)
    elif args.output_format == common.CLI_SAVED_MODEL:
        api.graph_model_to_saved_model(
            args.input_path,
            args.output_path,
            tags=args.saved_model_tags,
            signature_def_map=_get_signature(args),
            signature_key_map=_get_signature_keys(args),
            compat_mode=args.compat_mode)
    else:
        raise ValueError(f"Unsupported output format: {args.output_format}")

    end_time = time.perf_counter()
    info("Done.")
    info(f"Conversion took {end_time-start_time:.3f}s")