Пример #1
0
  def test_soft_matching(self):
    """Checks a `[None]`-shaped input signature before and after `self.cycle`.

    The function is traced with a rank-1, unknown-length int32 signature.
    Before the cycle, calls with vectors of different lengths must share a
    single concrete function. After the cycle (presumably a save/load round
    trip -- confirm against the test base class), vectors of other lengths
    must still be accepted, while a scalar must be rejected.
    """

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
    def func(x):
      return 2 * x

    root = tracking.Checkpointable()
    root.f = func

    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
    self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy())

    # Both calls above fit the one declared signature, so exactly one
    # concrete function is expected.
    self.assertEqual(
        1, len(function_serialization.list_all_concrete_functions(root.f)))

    imported = self.cycle(root)

    # We cannot call the function with a constant of shape ().
    # The call is deliberately not wrapped in an assertEqual: a failed
    # assertEqual raises AssertionError itself, which would satisfy
    # assertRaises even if the call had unexpectedly succeeded.
    with self.assertRaises(AssertionError):
      imported.f(constant_op.constant(2)).numpy()

    # TODO(vbardiovsky): When classes are revived with input_signatures, we
    # should also check that the calls below are not generating any more
    # concrete functions.
    self.assertAllEqual([2, 4, 6, 8],
                        imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
    self.assertAllEqual([2, 4, 6],
                        imported.f(constant_op.constant([1, 2, 3])).numpy())
Пример #2
0
    def __init__(self, root):
        """Gathers the objects found from `root` and their polymorphic functions.

        Args:
          root: The object handed to `util.find_objects` as the starting point.
        """
        self.nodes, self.node_ids, self.slot_variables = util.find_objects(
            root)
        self.polymorphic_functions = util.ObjectIdentityDictionary()

        # Polymorphic functions become nodes as well. `self.nodes` grows while
        # it is being iterated, so every function appended below is itself
        # visited by a later iteration of this loop.
        for node in self.nodes:
            functions_of_node = self._list_polymorphic_functions(node)
            self.polymorphic_functions[node] = functions_of_node
            for polymorphic_function in functions_of_node.values():
                if polymorphic_function not in self.node_ids:
                    # The next free index doubles as the new node's id.
                    self.node_ids[polymorphic_function] = len(self.nodes)
                    self.nodes.append(polymorphic_function)
                # The result is discarded; the listing is done only for its
                # side effects:
                #  - it populates the cache of polymorphic functions that have
                #    an input_signature but have not been called yet;
                #  - it triggers the side effects of creating concrete
                #    functions, e.g. variables created on first run.
                function_serialization.list_all_concrete_functions(
                    polymorphic_function)
Пример #3
0
def _fill_meta_graph_def(meta_graph_def, obj, signature_functions,
                         object_saver):
    """Fills `meta_graph_def` with a graph that calls `signature_functions`.

    Args:
      meta_graph_def: The MetaGraphDef proto to fill; it is modified in place.
      obj: The checkpointable object being exported.
      signature_functions: A dictionary mapping signature keys to concrete
        functions containing signatures to add to the MetaGraph.
      object_saver: A CheckpointableSaver to add to the MetaGraph.

    Returns:
      An _AssetInfo, which contains information to help creating the
      SavedModel.
    """
    # Objects are listed from the eager context so that Optimizers hand back
    # the right Graph-dependent variables.
    reachable_objects = util.list_objects(obj)
    initializer_functions = _trace_resource_initializers(reachable_objects)
    exported_graph = ops.Graph()
    initializer_ops = []
    with exported_graph.as_default():
        object_map, resource_map, asset_info = _map_resources(
            reachable_objects)
        for initializer_function in initializer_functions:
            # An initializer must run after the asset initializers of every
            # resource it captures; captures with no asset initializer are
            # skipped.
            looked_up = (
                asset_info.asset_initializers_by_resource.get(captured, None)
                for captured in initializer_function.graph.external_captures)
            asset_dependencies = [
                initializer for initializer in looked_up
                if initializer is not None]
            with ops.control_dependencies(asset_dependencies):
                initializer_call = _call_function_with_mapped_captures(
                    initializer_function, [], resource_map)
                initializer_ops.append(initializer_call)
        with ops.control_dependencies(initializer_ops):
            init_op = control_flow_ops.no_op()
        # The same op goes into both the main_op collection and the init_op
        # signature. The collection exists for compatibility with older
        # loader APIs; only one of the two will be executed.
        main_op_collection = meta_graph_def.collection_def[
            constants.MAIN_OP_KEY]
        main_op_collection.node_list.value.append(init_op.name)
        init_signature_entry = meta_graph_def.signature_def[
            constants.INIT_OP_SIGNATURE_KEY]
        init_signature_entry.CopyFrom(
            signature_def_utils.op_signature_def(
                init_op, constants.INIT_OP_SIGNATURE_KEY))

    # Saving an object-based checkpoint gathers variables again. The
    # gathering has to happen in the eager context so Optimizers save the
    # right set of variables, while the save/restore operations themselves
    # belong in the exported graph (hence the `to_graph` argument).
    saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)

    # The concrete functions to add to the MetaGraph must be resolved while
    # still in eager mode.
    concrete_functions = []
    for reachable_object in reachable_objects:
        polymorphic_by_name = (
            function_serialization.list_all_polymorphic_functions(
                reachable_object))
        for polymorphic_function in polymorphic_by_name.values():
            concrete_functions.extend(
                function_serialization.list_all_concrete_functions(
                    polymorphic_function))

    with exported_graph.as_default():
        signatures = _generate_signatures(signature_functions, resource_map)
        for concrete_function in concrete_functions:
            concrete_function.add_to_graph()
        meta_graph_def.saver_def.CopyFrom(saver.to_proto())
    exported_graph_def = exported_graph.as_graph_def(add_shapes=True)
    # Break reference cycles so repeated export()s don't create work for the
    # garbage collector.
    ops.dismantle_graph(exported_graph)

    meta_graph_def.graph_def.CopyFrom(exported_graph_def)
    meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
    meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
    for signature_key, signature_def in signatures.items():
        meta_graph_def.signature_def[signature_key].CopyFrom(signature_def)
    meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
    return asset_info