コード例 #1
0
ファイル: tf_loader.py プロジェクト: guschmue/tensorflow-onnx
def freeze_session(sess, input_names=None, output_names=None, get_tables=False):
    """Freezes the state of a session into a pruned computation graph.

    Args:
        sess: session whose graph and variable values are frozen.
        input_names: tensor names (e.g. "x:0") of the graph inputs; their op
            names are added to the nodes kept by the conversion. ``None`` is
            treated as an empty list.
        output_names: tensor names of the graph outputs; their op names are
            the nodes kept by the conversion. ``None`` is treated as empty.
        get_tables: when True, also run the table initializers and export the
            contents of every hash table found in the frozen graph.

    Returns:
        The frozen ``GraphDef``, or, when ``get_tables`` is True, a tuple
        ``(graph_def, initialized_tables)`` where ``initialized_tables`` maps a
        table's shared_name to its exported ``(keys, values)``.
    """
    # Strip the ":port" suffix to get op names; taking the first ':'-separated
    # part also works for names that carry no port at all.
    output_node_names = [name.split(':')[0] for name in output_names or []]
    keep_var_names = [name.split(':')[0] for name in input_names or []]
    with sess.graph.as_default():
        # Everything listed here is passed to convert_variables_to_constants as
        # a node name to keep: the outputs, all global variables, and the inputs.
        output_node_names += [v.op.name for v in tf_global_variables()]
        output_node_names += keep_var_names
        graph_def = sess.graph.as_graph_def(add_shapes=True)
        for node in graph_def.node:
            node.device = ""
        graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
        if get_tables:
            table_info = get_hash_table_info(graph_def)
            initialized_tables = {}
            tf.tables_initializer().run(session=sess)
            for info in table_info:
                # Tables are keyed (and looked up) by their shared_name.
                n = info.shared_name
                h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=n)
                try:
                    k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
                    k, v = sess.run([k, v])
                    initialized_tables[n] = (k, v)
                except Exception:  # pylint: disable=broad-except
                    logger.warning("Could not initialize table with shared_name = %r", n)
            return graph_def, initialized_tables
    return graph_def
コード例 #2
0
ファイル: tf_loader.py プロジェクト: guschmue/tensorflow-onnx
def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
    """Freeze ``concrete_func`` into a GraphDef and export the hash tables it uses.

    Args:
        trackable: object searched (via ``_get_hash_table_info_from_trackable``)
            for table resources matching the lookups in the frozen graph.
        concrete_func: concrete function to freeze. Its captures are modified
            temporarily to work around a TF freezing bug and are always put
            back before this function returns or raises.
        inputs: input tensor names, forwarded to ``from_function``.
        outputs: output tensor names, forwarded to ``from_function``.
        large_model: forwarded to ``from_function``.

    Returns:
        ``(frozen_graph, initialized_tables)`` where ``initialized_tables`` maps
        a table shared_name to its exported ``(keys, values)`` numpy arrays.

    Raises:
        ValueError: if freezing hits the 2GB protobuf limit (re-raised with a
            hint to set large_model); other ValueErrors propagate unchanged.
    """
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, placeholder_to_resource, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # We might be returning the concrete_func (and callers may keep using it
        # after an error), so put it back in working order on every path.
        _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)

    table_info = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    _get_hash_table_info_from_trackable(trackable, table_info,
                                        removed_resource_to_placeholder, placeholder_to_table_info)

    initialized_tables = {}
    for info in table_info:
        if info.shared_name is not None:
            h = lookup_ops.hash_table_v2(info.key_dtype, info.val_dtype, shared_name=info.shared_name)
            n = info.shared_name
        elif info.resource_input in placeholder_to_resource and info.resource_input not in placeholder_to_table_info:
            # We found a lookup op with no corresponding HashTable op, but we can associate the placeholder input
            # from the op with the resource handle from graph captures and make up a shared_name
            h = placeholder_to_resource[info.resource_input]
            n = str(uuid.uuid4()).encode()
            info.shared_name = n
            placeholder_to_table_info[info.resource_input] = info
        else:
            # Found a lookup op but the corresponding HashTable op has already been found and processed.
            continue
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, info.key_dtype, info.val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            logger.warning("Could not initialize table with shared_name = %r", n)

    for placeholder in removed_resource_to_placeholder.values():
        if placeholder not in placeholder_to_table_info:
            logger.error("Could not find table resource to replace placeholder %s", placeholder)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, initialized_tables
コード例 #3
0
ファイル: tf_loader.py プロジェクト: APX103/tensorflow-onnx
def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
    """Freeze ``concrete_func`` into a GraphDef and export the hash tables it uses.

    Args:
        trackable: object searched (via ``_get_hash_table_info_from_trackable``)
            for table resources matching the tables found in the frozen graph.
        concrete_func: concrete function to freeze. Its captures are modified
            temporarily to work around a TF freezing bug and are always put
            back before this function returns or raises.
        inputs: input tensor names, forwarded to ``from_function``.
        outputs: output tensor names, forwarded to ``from_function``.
        large_model: forwarded to ``from_function``.

    Returns:
        ``(frozen_graph, initialized_tables)`` where ``initialized_tables`` maps
        a table shared_name to its exported ``(keys, values)`` numpy arrays.

    Raises:
        ValueError: if freezing hits the 2GB protobuf limit (re-raised with a
            hint to set large_model); other ValueErrors propagate unchanged.
    """
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # We might be returning the concrete_func (and callers may keep using it
        # after an error), so put it back in working order on every path.
        _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
                                        removed_resource_to_placeholder, placeholder_to_table_info)

    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            logger.warning("Could not initialize table with shared_name = %r", n)

    for placeholder in removed_resource_to_placeholder.values():
        if placeholder not in placeholder_to_table_info:
            logger.error("Could not find table resource to replace placeholder %s", placeholder)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, initialized_tables
コード例 #4
0
def _from_saved_model_v2(model_path, input_names, output_names, tag,
                         signature_def, concrete_function_index, large_model):
    """Load tensorflow graph from saved_model.

    Args:
        model_path: path of the TF2 saved_model.
        input_names: input tensor names to use, or None to derive them from the
            selected concrete function.
        output_names: output tensor names to use, or None to derive them from
            the selected concrete function.
        tag: saved_model tags; None means ['serve'], '' means [[]].
        signature_def: sequence with at most one signature key (None or empty
            selects the first signature whose key does not start with "_").
        concrete_function_index: when not None, index into
            ``imported.__call__.concrete_functions``; takes precedence over
            ``signature_def``.
        large_model: forwarded to ``from_function``.

    Returns:
        ``(frozen_graph, inputs, outputs, concrete_func, imported,
        initialized_tables, tensors_to_rename)``. ``tensors_to_rename`` maps
        output tensor names to structured output keys when ``output_names`` is
        None and the function's structured outputs are a dict.

    Raises:
        ValueError: if freezing hits the 2GB protobuf limit (re-raised with a
            hint to use --large_model); other ValueErrors from
            ``from_function`` propagate unchanged.
    """

    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
    wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
    err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
    err_no_call = "Model doesn't contain usable concrete functions under  __call__. Try --signature-def instead."
    err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
    err_no_sig = "No signatures found in model. Try --concrete_function instead."
    err_sig_nomatch = "Specified signature not in model %s"
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

    if tag is None:
        tag = ['serve']
        logger.warning(wrn_no_tag)

    if tag == '':
        tag = [[]]
        logger.warning(wrn_empty_tag)

    # A missing signature_def behaves like an empty one instead of failing in len().
    signature_def = signature_def or []
    utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
    imported = tf.saved_model.load(model_path, tags=tag)  # pylint: disable=no-value-for-parameter

    all_sigs = imported.signatures.keys()
    valid_sigs = [s for s in all_sigs if not s.startswith("_")]
    logger.info("Signatures found in model: %s",
                "[" + ",".join(valid_sigs) + "].")

    concrete_func = None
    if concrete_function_index is not None:
        utils.make_sure(hasattr(imported, "__call__"), err_no_call)
        utils.make_sure(
            concrete_function_index < len(
                imported.__call__.concrete_functions), err_index,
            concrete_function_index,
            len(imported.__call__.concrete_functions) - 1)
        args, kwargs = imported.__call__.concrete_functions[
            concrete_function_index].structured_input_signature
        concrete_func = imported.__call__.get_concrete_function(
            *args, **kwargs)
    elif signature_def:
        utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch,
                        signature_def[0])
        concrete_func = imported.signatures[signature_def[0]]
    else:
        utils.make_sure(len(valid_sigs) > 0, err_no_sig)
        logger.warning(wrn_sig_1, valid_sigs[0])
        concrete_func = imported.signatures[valid_sigs[0]]

    tensors_to_rename = {}
    if input_names is None:
        inputs = [
            tensor.name for tensor in concrete_func.inputs
            if tensor.dtype != tf.dtypes.resource
        ]
        if concrete_func.structured_input_signature is not None:
            args, kwargs = concrete_func.structured_input_signature
            structured_inputs = [
                t.name for t in args if isinstance(t, tf.TensorSpec)
            ] + sorted(kwargs.keys())
            structured_inputs = {inp + ":0" for inp in structured_inputs}
            # Only narrow to the structured inputs if at least one matches.
            if any(inp in structured_inputs for inp in inputs):
                inputs = [inp for inp in inputs if inp in structured_inputs]
    else:
        inputs = input_names

    if output_names is None:
        outputs = [
            tensor.name for tensor in concrete_func.outputs
            if tensor.dtype != tf.dtypes.resource
        ]
        if isinstance(concrete_func.structured_outputs, dict):
            # outputs are sorted, sort structured_outputs the same way
            structured_outputs = sorted(
                concrete_func.structured_outputs.keys())
            tensors_to_rename.update(zip(outputs, structured_outputs))
            logger.info("Output names: %r", structured_outputs)
        else:
            logger.info("Output names: %r", outputs)
    else:
        outputs = output_names
        logger.info(
            "Outputs not left as None; will use provided names not structured output names."
        )

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs,
                                     large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in
               ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # We might be returning the concrete_func (and callers may keep using it
        # after an error), so put it back in working order on every path.
        _restore_captured_resources(concrete_func, graph_captures_copy,
                                    func_captures_copy)

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    _get_hash_table_info_from_trackable(imported, table_names, key_dtypes,
                                        value_dtypes,
                                        removed_resource_to_placeholder,
                                        placeholder_to_table_info)

    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            logger.warning("Could not initialize table with shared_name = %r",
                           n)

    for placeholder in removed_resource_to_placeholder.values():
        if placeholder not in placeholder_to_table_info:
            logger.error(
                "Could not find table resource to replace placeholder %s",
                placeholder)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables, tensors_to_rename
コード例 #5
0
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag,
                         signature_names):
    """Load tensorflow graph from saved_model.

    Args:
        sess: TF1 session the saved_model is loaded into and frozen from.
        model_path: path of the saved_model.
        input_names: input tensor names, or None to collect them from the
            selected signatures.
        output_names: output tensor names, or None to collect them from the
            selected signatures.
        tag: a tag or list of tags; None means SERVING, '' means no tags.
        signature_names: signature keys to use; None or empty selects every
            signature whose key does not start with "_".

    Returns:
        ``(frozen_graph, input_names, output_names, initialized_tables,
        tensors_to_rename)``. ``tensors_to_rename`` maps collected output
        tensor names to their signature output keys.
    """

    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tags = []"
    wrn_empty_sig = "'--signature_def' not provided. Using all signatures."

    if tag is None:
        tag = [tf.saved_model.tag_constants.SERVING]
        logger.warning(wrn_no_tag)

    if not signature_names:
        logger.warning(wrn_empty_sig)

    if tag == '':
        tag = []
        logger.warning(wrn_empty_tag)

    if not isinstance(tag, list):
        tag = [tag]

    imported = tf.saved_model.loader.load(sess, tag, model_path)
    # Test membership only when names were given: ``key in None`` would raise
    # TypeError for the (warned-about) "no signatures provided" case.
    if signature_names:
        signatures = [key for key in imported.signature_def.keys()
                      if key in signature_names]
    else:
        signatures = [key for key in imported.signature_def.keys()
                      if not key.startswith("_")]
    try:
        from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
        # pylint: disable=unnecessary-lambda
        get_signature_def = lambda meta_graph_def, k: \
            signature_def_utils.get_signature_def_by_key(meta_graph_def, k)
    except ImportError:
        # TF1.12 changed the api
        get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[
            k]

    if input_names is None:
        input_names = []
        for sig_key in signatures:
            inputs_tensor_info = get_signature_def(imported, sig_key).inputs
            for _, input_tensor in inputs_tensor_info.items():
                if input_tensor.name not in input_names:
                    input_names.append(input_tensor.name)
    tensors_to_rename = {}
    if output_names is None:
        output_names = []
        for sig_key in signatures:
            outputs_tensor_info = get_signature_def(imported, sig_key).outputs
            for structured_name, output_tensor in outputs_tensor_info.items():
                if output_tensor.name not in output_names:
                    output_names.append(output_tensor.name)
                    tensors_to_rename[output_tensor.name] = structured_name
    frozen_graph = freeze_session(sess,
                                  input_names=input_names,
                                  output_names=output_names)
    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    initialized_tables = {}
    # Create the initializer/export ops in the session's graph and run them in
    # that session explicitly instead of relying on ambient defaults.
    with sess.graph.as_default():
        tf.tables_initializer().run(session=sess)
        for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
            h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
            try:
                keys, values = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
                keys, values = sess.run([keys, values])
                initialized_tables[n] = (keys, values)
            except Exception:  # pylint: disable=broad-except
                logger.warning("Could not initialize table with shared_name = %r",
                               n)
    return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename
コード例 #6
0
    def freeze_and_run_tf(self, func, feed_dict, outputs, as_session,
                          premade_placeholders, large_model, constant_fold):
        """Run ``func`` in TensorFlow and capture the graph it produces.

        Under TF2 (unless ``as_session``), ``func`` is executed eagerly and then
        traced with ``tf.function`` and converted by ``from_function``. Otherwise
        it is built and run in a session, frozen with ``freeze_session``, its hash
        tables are exported, and the frozen graph is run through ``tf_optimize``.
        Either way the resulting GraphDef is saved under ``test_data_directory``
        as ``<test method name>_after_tf_optimize.pb``.

        Args:
            func: callable computing the outputs from the input tensors.
            feed_dict: maps input tensor names (with port) to numpy values.
            outputs: output tensor names (with port).
            as_session: use the session path even under TF2.
            premade_placeholders: in the session path, call ``func`` with no
                arguments (presumably it creates its own placeholders).
            large_model: forwarded to ``from_function`` (eager path).
            constant_fold: forwarded to ``tf_optimize`` as ``fold_constant``.

        Returns:
            ``(result, graph_def, initialized_tables)``: the list of output
            values, the graph, and a map from table shared_name to
            ``(keys, values)`` (None on the eager path).
        """
        np.random.seed(1)  # Make it reproducible.
        clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
        if is_tf2() and not as_session:
            #
            # use eager to execute the tensorflow func
            #
            # numpy doesn't work for all ops, make it tf.Tensor()
            input_tensors = [
                tf.TensorSpec(shape=v.shape,
                              dtype=tf.as_dtype(v.dtype),
                              name=utils.node_name(k))
                for k, v in feed_dict.items()
            ]
            input_list = [
                tf.convert_to_tensor(v,
                                     dtype=tf.as_dtype(v.dtype),
                                     name=utils.node_name(k))
                for k, v in feed_dict.items()
            ]
            tf.random.set_seed(1)
            result = func(*input_list)
            if isinstance(result, (list, tuple)):
                # list or tuple
                result = [x.numpy() for x in result]
            else:
                # single result
                result = [result.numpy()]

            # now make the eager functions a graph
            concrete_func = tf.function(func,
                                        input_signature=tuple(input_tensors))
            concrete_func = concrete_func.get_concrete_function()
            graph_def = from_function(concrete_func,
                                      input_names=list(feed_dict.keys()),
                                      output_names=outputs,
                                      large_model=large_model)
            initialized_tables = None
        else:
            #
            # use graph to execute the tensorflow func
            #
            with tf_session() as sess:
                tf_set_random_seed(1)
                input_list = []
                if not premade_placeholders:
                    for k, v in clean_feed_dict.items():
                        input_list.append(
                            tf_placeholder(name=k,
                                           shape=v.shape,
                                           dtype=tf.as_dtype(v.dtype)))
                func(*input_list)
                variables_lib.global_variables_initializer().run()
                tf_tables_initializer().run()

                output_dict = []
                for out_name in outputs:
                    output_dict.append(sess.graph.get_tensor_by_name(out_name))
                result = sess.run(output_dict, feed_dict=feed_dict)
                graph_def = freeze_session(sess,
                                           input_names=list(feed_dict.keys()),
                                           output_names=outputs)
                table_names, key_dtypes, value_dtypes = get_hash_table_info(
                    graph_def)
                initialized_tables = {}
                for n, k_dtype, val_dtype in zip(table_names, key_dtypes,
                                                 value_dtypes):
                    h = lookup_ops.hash_table_v2(k_dtype,
                                                 val_dtype,
                                                 shared_name=n)
                    k, v = lookup_ops.lookup_table_export_v2(
                        h, k_dtype, val_dtype)
                    # Fetch keys and values from a single run of the export op so
                    # the two arrays are guaranteed to line up element by element.
                    keys, values = sess.run([k, v])
                    initialized_tables[n] = (keys, values)

            tf_reset_default_graph()
            with tf_session() as sess:
                tf.import_graph_def(graph_def, name='')
                graph_def = tf_optimize(list(feed_dict.keys()),
                                        outputs,
                                        graph_def,
                                        fold_constant=constant_fold)

        model_path = os.path.join(
            self.test_data_directory,
            self._testMethodName + "_after_tf_optimize.pb")
        utils.save_protobuf(model_path, graph_def)
        self.logger.debug("created file  %s", model_path)
        return result, graph_def, initialized_tables
コード例 #7
0
    def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
                      convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
                      check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
                      large_model=False, premade_placeholders=False):
        """Run ``func`` in TensorFlow, convert it to ONNX, and compare the results.

        Args:
            func: callable computing the outputs from the input tensors.
            feed_dict: maps input tensor names (with port) to numpy values.
            input_names_with_port: overwritten with ``list(feed_dict)``.
            output_names_with_port: output tensor names (with port).
            rtol, atol: tolerances passed to ``assertAllClose``.
            convert_var_to_const: not used in this body.
            constant_fold: forwarded to ``tf_optimize`` (session path).
            check_value, check_shape, check_dtype: enable each comparison.
            process_args: extra keyword arguments for ``process_tf_graph``.
            onnx_feed_dict: feed for the ONNX backend; defaults to ``feed_dict``.
            graph_validator: optional predicate asserted true on the final graph.
            as_session: use the session path even under TF2.
            large_model: forwarded to ``from_function`` and ``run_backend``; when
                True, ``compress_graph_def`` supplies ``const_node_values``.
            premade_placeholders: in the session path, call ``func`` with no
                arguments.

        Returns:
            The graph from ``process_tf_graph`` after ``optimizer.optimize_graph``.
        """
        # optional - passed to process_tf_graph
        if process_args is None:
            process_args = {}
        # optional - pass distinct feed_dict to onnx runtime
        if onnx_feed_dict is None:
            onnx_feed_dict = feed_dict
        input_names_with_port = list(feed_dict)
        tf_reset_default_graph()
        graph_def = None
        initialized_tables = None

        np.random.seed(1)  # Make it reproducible.
        clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
        if is_tf2() and not as_session:
            #
            # use eager to execute the tensorflow func
            #
            # numpy doesn't work for all ops, make it tf.Tensor()
            input_tensors = [tf.TensorSpec(shape=v.shape, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                             for k, v in feed_dict.items()]
            input_list = [tf.convert_to_tensor(v, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                          for k, v in feed_dict.items()]
            tf.random.set_seed(1)
            expected = func(*input_list)
            if isinstance(expected, (list, tuple)):
                # list or tuple
                expected = [x.numpy() for x in expected]
            else:
                # single result
                expected = [expected.numpy()]

            # now make the eager functions a graph
            concrete_func = tf.function(func, input_signature=tuple(input_tensors))
            concrete_func = concrete_func.get_concrete_function()
            graph_def = from_function(concrete_func,
                                      input_names=list(feed_dict.keys()),
                                      output_names=output_names_with_port,
                                      large_model=large_model)
        else:
            #
            # use graph to execute the tensorflow func
            #
            with tf_session() as sess:
                tf_set_random_seed(1)
                input_list = []
                if not premade_placeholders:
                    for k, v in clean_feed_dict.items():
                        input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
                func(*input_list)
                variables_lib.global_variables_initializer().run()
                tf_tables_initializer().run()

                output_dict = []
                for out_name in output_names_with_port:
                    output_dict.append(sess.graph.get_tensor_by_name(out_name))
                expected = sess.run(output_dict, feed_dict=feed_dict)
                graph_def = freeze_session(sess,
                                           input_names=list(feed_dict.keys()),
                                           output_names=output_names_with_port)
                table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
                initialized_tables = {}
                for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
                    h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
                    k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
                    # Fetch keys and values from a single run of the export op so
                    # the two arrays are guaranteed to line up element by element.
                    keys, values = sess.run([k, v])
                    initialized_tables[n] = (keys, values)

            tf_reset_default_graph()
            with tf_session() as sess:
                tf.import_graph_def(graph_def, name='')
                graph_def = tf_optimize(list(feed_dict.keys()), output_names_with_port,
                                        graph_def, fold_constant=constant_fold)

        tf_reset_default_graph()
        with tf_session() as sess:
            const_node_values = None
            if large_model:
                const_node_values = compress_graph_def(graph_def)
            tf.import_graph_def(graph_def, name='')

            if self.config.is_debug_mode:
                model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
                utils.save_protobuf(model_path, graph_def)
                self.logger.debug("created file  %s", model_path)

            g = process_tf_graph(sess.graph, opset=self.config.opset,
                                 input_names=list(feed_dict.keys()),
                                 output_names=output_names_with_port,
                                 target=self.config.target,
                                 const_node_values=const_node_values,
                                 initialized_tables=initialized_tables,
                                 **process_args)
            g = optimizer.optimize_graph(g, catch_errors=False)
            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)

        for expected_val, actual_val in zip(expected, actual):
            if check_value:
                # ``np.object`` was removed in NumPy 1.24; the builtin ``object``
                # compares equal to the object dtype on every NumPy version.
                if expected_val.dtype == object:
                    # TF returns string tensors as bytes; decode before comparing.
                    decode = np.vectorize(lambda x: x.decode('UTF-8'))
                    expected_val_str = decode(expected_val)
                    self.assertAllEqual(expected_val_str, actual_val)
                else:
                    self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
            if check_dtype:
                self.assertEqual(expected_val.dtype, actual_val.dtype)
            # why need shape checke: issue when compare [] with scalar
            # https://github.com/numpy/numpy/issues/11071
            if check_shape:
                self.assertEqual(expected_val.shape, actual_val.shape)

        if graph_validator:
            self.assertTrue(graph_validator(g))

        return g
コード例 #8
0
def _from_saved_model_v2(model_path, input_names, output_names, tag,
                         signature_def, concrete_function_index, large_model):
    """Load tensorflow graph from saved_model.

    Args:
        model_path: path of the TF2 saved_model.
        input_names: optional input tensor names; when given, only those that
            are inputs of the selected concrete function are kept, in the
            order given.
        output_names: optional output tensor names; filtered the same way
            against the function's outputs.
        tag: saved_model tags; None means ['serve'], '' means [[]].
        signature_def: sequence with at most one signature key (None or empty
            selects the first signature whose key does not start with "_").
        concrete_function_index: when not None, index into
            ``imported.__call__.concrete_functions``; takes precedence over
            ``signature_def``.
        large_model: forwarded to ``from_function``.

    Returns:
        ``(frozen_graph, inputs, outputs, concrete_func, imported,
        initialized_tables)``.

    Raises:
        ValueError: if freezing hits the 2GB protobuf limit (re-raised with a
            hint to use --large_model); other ValueErrors from
            ``from_function`` propagate unchanged.
    """

    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
    wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
    err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
    err_no_call = "Model doesn't contain usable concrete functions under  __call__. Try --signature-def instead."
    err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
    err_no_sig = "No signatures found in model. Try --concrete_function instead."
    err_sig_nomatch = "Specified signature not in model %s"
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

    if tag is None:
        tag = ['serve']
        logger.warning(wrn_no_tag)

    if tag == '':
        tag = [[]]
        logger.warning(wrn_empty_tag)

    # A missing signature_def behaves like an empty one instead of failing in len().
    signature_def = signature_def or []
    utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
    imported = tf.saved_model.load(model_path, tags=tag)  # pylint: disable=no-value-for-parameter

    all_sigs = imported.signatures.keys()
    valid_sigs = [s for s in all_sigs if not s.startswith("_")]
    logger.info("Signatures found in model: %s",
                "[" + ",".join(valid_sigs) + "].")

    concrete_func = None
    if concrete_function_index is not None:
        utils.make_sure(hasattr(imported, "__call__"), err_no_call)
        utils.make_sure(
            concrete_function_index < len(
                imported.__call__.concrete_functions), err_index,
            concrete_function_index,
            len(imported.__call__.concrete_functions) - 1)
        sig = imported.__call__.concrete_functions[
            concrete_function_index].structured_input_signature[0]
        concrete_func = imported.__call__.get_concrete_function(*sig)
    elif signature_def:
        utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch,
                        signature_def[0])
        concrete_func = imported.signatures[signature_def[0]]
    else:
        utils.make_sure(len(valid_sigs) > 0, err_no_sig)
        logger.warning(wrn_sig_1, valid_sigs[0])
        concrete_func = imported.signatures[valid_sigs[0]]

    inputs = [
        tensor.name for tensor in concrete_func.inputs
        if tensor.dtype != tf.dtypes.resource
    ]
    outputs = [
        tensor.name for tensor in concrete_func.outputs
        if tensor.dtype != tf.dtypes.resource
    ]

    # Filter by user specified inputs/outputs. Keep the user's order (deduped)
    # rather than a set intersection, whose iteration order depends on string
    # hashing and so can change from one process to the next.
    if input_names:
        available_inputs = set(inputs)
        inputs = [name for name in dict.fromkeys(input_names) if name in available_inputs]
    if output_names:
        available_outputs = set(outputs)
        outputs = [name for name in dict.fromkeys(output_names) if name in available_outputs]

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs,
                                     large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in
               ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            logger.warning("Could not initialize table with shared_name = %r",
                           n)

    return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables
コード例 #9
0
def _from_saved_model_v2(model_path, input_names, output_names, tag,
                         signature_def, concrete_function_index, large_model):
    """Load tensorflow graph from saved_model.

    Loads the saved_model with ``tf.saved_model.load``, selects one concrete
    function (by ``concrete_function_index`` if given, else by
    ``signature_def``, else the first signature whose name does not start
    with ``_``), freezes it with ``from_function`` and exports the contents of
    the hash tables it references.

    Args:
        model_path: path passed to ``tf.saved_model.load``.
        input_names: optional list of tensor names. When non-empty, the
            function's non-resource inputs are restricted to these names,
            in the order given (duplicates dropped).
        output_names: like ``input_names``, for the outputs.
        tag: tags to load. ``None`` means ``['serve']``; ``''`` means ``[[]]``.
        signature_def: sequence with at most one signature name; ``None`` or
            empty means "not specified".
        concrete_function_index: optional index into the concrete functions
            of ``imported.__call__``; takes precedence over ``signature_def``.
        large_model: forwarded to ``from_function``.

    Returns:
        ``(frozen_graph, inputs, outputs, concrete_func, imported,
        initialized_tables)``, where ``initialized_tables`` maps a table's
        shared_name to a ``(keys, values)`` pair obtained via ``.numpy()``.

    Raises:
        ValueError: propagated from ``from_function``; protobuf-size failures
            are replaced by a message suggesting ``--large_model``.
    """

    wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
    wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
    wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
    err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
    err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature_def instead."
    err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
    err_no_sig = "No signatures found in model. Try --concrete_function instead."
    err_sig_nomatch = "Specified signature not in model %s"
    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

    if tag is None:
        tag = ['serve']
        logger.warning(wrn_no_tag)

    if tag == '':
        tag = [[]]
        logger.warning(wrn_empty_tag)

    # Treat None like an empty list, consistent with `elif signature_def:` below.
    utils.make_sure(not signature_def or len(signature_def) < 2, err_many_sig,
                    str(signature_def))
    imported = tf.saved_model.load(model_path, tags=tag)  # pylint: disable=no-value-for-parameter

    all_sigs = imported.signatures.keys()
    # Signatures whose names start with "_" are treated as private and skipped.
    valid_sigs = [s for s in all_sigs if not s.startswith("_")]
    logger.info("Signatures found in model: %s",
                "[" + ",".join(valid_sigs) + "].")

    concrete_func = None
    if concrete_function_index is not None:
        utils.make_sure(hasattr(imported, "__call__"), err_no_call)
        call_funcs = imported.__call__.concrete_functions
        utils.make_sure(concrete_function_index < len(call_funcs), err_index,
                        concrete_function_index, len(call_funcs) - 1)
        sig = call_funcs[concrete_function_index].structured_input_signature[0]
        concrete_func = imported.__call__.get_concrete_function(*sig)
    elif signature_def:
        utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch,
                        signature_def[0])
        concrete_func = imported.signatures[signature_def[0]]
    else:
        utils.make_sure(len(valid_sigs) > 0, err_no_sig)
        logger.warning(wrn_sig_1, valid_sigs[0])
        concrete_func = imported.signatures[valid_sigs[0]]

    # Resource-typed tensors (variables, tables) are not real model inputs/outputs.
    inputs = [tensor.name for tensor in concrete_func.inputs
              if tensor.dtype != tf.dtypes.resource]
    outputs = [tensor.name for tensor in concrete_func.outputs
               if tensor.dtype != tf.dtypes.resource]

    # Filter by user-specified inputs/outputs. Walk the (de-duplicated) user
    # list instead of intersecting two sets, so the resulting order is
    # deterministic and follows what the caller asked for.
    if input_names:
        available_inputs = set(inputs)
        inputs = [name for name in dict.fromkeys(input_names)
                  if name in available_inputs]
    if output_names:
        available_outputs = set(outputs)
        outputs = [name for name in dict.fromkeys(output_names)
                   if name in available_outputs]

    # Avoid errors due to bug in TF freezing
    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
        _remove_non_variable_resources_from_captures(concrete_func)

    try:
        frozen_graph = from_function(concrete_func, inputs, outputs,
                                     large_model)
    except ValueError as e:
        if any(msg in str(e) for msg in
               ["exceeds maximum protobuf size of 2GB", "string too long"]):
            raise ValueError(err_large_model) from e
        raise
    finally:
        # We might be returning the concrete_func (and `imported` still owns
        # it), so put it back in working order even if freezing failed.
        _restore_captured_resources(concrete_func, graph_captures_copy,
                                    func_captures_copy)

    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
    placeholder_to_table_info = {}
    # pylint: disable=protected-access
    if hasattr(imported, '_table') and hasattr(imported._table,
                                               '_create_resource'):
        # Add tables from saved_model table initializers
        initializer = imported._table._create_resource.concrete_functions[
            0].function_def
        new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(
            initializer.node_def)
        table_names.extend(new_names)
        key_dtypes.extend(new_k_dtypes)
        value_dtypes.extend(new_v_dtypes)
        # A placeholder was substituted for this table's resource during
        # freezing; remember which table it stands for so it can be replaced.
        table_handle = id(imported._table.resource_handle)
        if table_handle in removed_resource_to_placeholder and len(
                new_names) == 1:
            table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
            placeholder_to_table_info[
                removed_resource_to_placeholder[table_handle]] = table_info
    # pylint: enable=protected-access

    initialized_tables = {}
    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
        try:
            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
            initialized_tables[n] = (k.numpy(), v.numpy())
        except Exception:  # pylint: disable=broad-except
            # Best effort: a table that cannot be exported is skipped.
            logger.warning("Could not initialize table with shared_name = %r",
                           n)

    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)

    return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables
コード例 #10
0
    def freeze_and_run_tf(self, func, feed_dict, outputs, as_session,
                          premade_placeholders, large_model):
        """Run ``func`` on ``feed_dict`` and return its results and a frozen graph.

        Under TF2 with ``as_session`` false, ``func`` is executed eagerly, then
        traced with ``tf.function`` and frozen via ``from_function``. Otherwise
        it is built and run in a TF1-style session, frozen with
        ``freeze_session``, its hash tables are exported, and the frozen graph
        is re-imported into a fresh default graph and passed to ``tf_optimize``.

        Args:
            func: callable invoked with one positional tensor per input (or
                with no arguments in session mode when
                ``premade_placeholders`` is set).
            feed_dict: mapping from input tensor name to value; values must
                expose ``.shape`` and ``.dtype`` (presumably numpy arrays).
            outputs: names of the output tensors.
            as_session: use the session path even under TF2.
            premade_placeholders: in session mode, skip creating placeholders
                because ``func`` provides its own inputs.
            large_model: forwarded to ``from_function`` (eager path only).

        Returns:
            ``(result, graph_def, initialized_tables)``: the list of output
            values, the frozen (and, on the session path, optimized)
            GraphDef, and a dict ``shared_name -> (keys, values)`` on the
            session path or ``None`` on the eager path.
        """
        np.random.seed(1)  # Make it reproducible.
        clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
        if is_tf2() and not as_session:
            #
            # use eager to execute the tensorflow func
            #
            # numpy doesn't work for all ops, make it tf.Tensor()
            input_tensors = [
                tf.TensorSpec(shape=v.shape,
                              dtype=tf.as_dtype(v.dtype),
                              name=utils.node_name(k))
                for k, v in feed_dict.items()
            ]
            input_list = [
                tf.convert_to_tensor(v,
                                     dtype=tf.as_dtype(v.dtype),
                                     name=utils.node_name(k))
                for k, v in feed_dict.items()
            ]
            tf.random.set_seed(1)
            result = func(*input_list)
            if isinstance(result, (list, tuple)):
                # list or tuple
                result = [x.numpy() for x in result]
            else:
                # single result
                result = [result.numpy()]

            # now make the eager functions a graph
            concrete_func = tf.function(func,
                                        input_signature=tuple(input_tensors))
            concrete_func = concrete_func.get_concrete_function()
            graph_def = from_function(concrete_func,
                                      input_names=list(feed_dict.keys()),
                                      output_names=outputs,
                                      large_model=large_model)
            initialized_tables = None
        else:
            #
            # use graph to execute the tensorflow func
            #
            with tf_session() as sess:
                tf_set_random_seed(1)
                input_list = []
                if not premade_placeholders:
                    for k, v in clean_feed_dict.items():
                        input_list.append(
                            tf_placeholder(name=k,
                                           shape=v.shape,
                                           dtype=tf.as_dtype(v.dtype)))
                func(*input_list)
                variables_lib.global_variables_initializer().run()
                tf_tables_initializer().run()

                output_tensors = [sess.graph.get_tensor_by_name(out_name)
                                  for out_name in outputs]
                result = sess.run(output_tensors, feed_dict=feed_dict)
                graph_def = freeze_session(sess,
                                           input_names=list(feed_dict.keys()),
                                           output_names=outputs)
                initialized_tables = {}
                for info in get_hash_table_info(graph_def):
                    if info.shared_name is None:
                        continue
                    h = lookup_ops.hash_table_v2(info.key_dtype,
                                                 info.val_dtype,
                                                 shared_name=info.shared_name)
                    k, v = lookup_ops.lookup_table_export_v2(
                        h, info.key_dtype, info.val_dtype)
                    # Fetch keys and values in a single run so both come from
                    # the same execution of the export op and stay aligned.
                    initialized_tables[info.shared_name] = tuple(sess.run([k, v]))

            # Re-import the frozen graph into a clean default graph and optimize it.
            tf_reset_default_graph()
            with tf_session() as sess:
                tf.import_graph_def(graph_def, name='')
                graph_def = tf_optimize(list(feed_dict.keys()), outputs,
                                        graph_def)

        return result, graph_def, initialized_tables