コード例 #1
0
    def _generate_ordered_node_ids(self):
        """Orders the node ids so that dependencies appear first.

        Builds a map from each candidate node id to the ids listed in
        ``self._proto.nodes[node_id].dependencies``. It then topologically
        sorts that map with ``trackable_utils.order_by_dependency``. Nodes
        already present (non-None) in ``self._loaded_nodes`` get an empty
        dependency list.

        Returns:
          A list of node ids in which every node appears after the nodes it
          depends on.

        Raises:
          ValueError: If a filter is set (``self._filtered_nodes`` is not None)
            and a node depends on an id outside the filter, or if the
            dependencies contain a cycle.
        """
        if self._filtered_nodes is None:
            # No filter: every node in the proto is a candidate.
            unordered_ids = range(len(self._proto.nodes))
        else:
            unordered_ids = list(self._filtered_nodes)

        # Maps node id -> ids of the nodes that must be created before it.
        dependency_map = {}
        for node_id in unordered_ids:
            deps = dependency_map[node_id] = []
            if self._loaded_nodes.get(node_id) is not None:
                # Deps are only used if the node has not been created.
                continue
            for reference in self._proto.nodes[node_id].dependencies:
                dep = reference.node_id
                deps.append(dep)
                if self._filtered_nodes is not None and dep not in self._filtered_nodes:
                    raise ValueError(
                        "Unable to partially load SavedModel since the specified filter "
                        "does not include all deserialization dependencies. Please "
                        "include this path in the filter: "
                        f"{self._pretty_printer.node_names[dep]}")

        try:
            return list(trackable_utils.order_by_dependency(dependency_map))
        except trackable_utils.CyclicDependencyError as err:
            # This should not happen since there is already a validation for cycles
            # when saving, but raise an error just in case. The adjacent literals
            # below are concatenated, so each piece needs its own trailing space.
            raise ValueError(
                "Encountered a cycle in the deserialization dependencies "
                "in the SavedModel. This is extremely unexpected, please "
                "file a bug and make sure you are not manually modifying "
                "the SavedModel.") from err
コード例 #2
0
ファイル: load.py プロジェクト: gglin001/tensorflow
    def _generate_ordered_node_ids(self):
        """Orders the node ids so that dependencies appear first.

        Builds a map from each candidate node id to the ids of the nodes that
        must be loaded before it, then topologically sorts that map with
        ``trackable_utils.order_by_dependency``. Dependencies come from two
        places:

        * the values of ``self._get_node_dependencies(proto)``;
        * for optimizer slot variables, ``proto.slot_variables``.

        Returns:
          A list of node ids in which every node appears after the nodes it
          depends on.

        Raises:
          ValueError: If a filter is set (``self._filtered_nodes`` is not None)
            and a node depends on an id outside the filter, or if the
            dependencies contain a cycle.
        """
        if self._filtered_nodes is None:
            # No filter: every node in the proto is a candidate.
            unordered_ids = range(len(self._proto.nodes))
        else:
            unordered_ids = list(self._filtered_nodes)

        # Maps node ids -> list of dependencies (ids of other nodes that must be
        # loaded before it). A defaultdict is used because the slot-variable loop
        # below can create an entry for a node id before the outer loop reaches
        # it. `dependency_map[node_id]` then extends that existing list instead
        # of resetting it.
        dependency_map = collections.defaultdict(list)
        for node_id in unordered_ids:
            deps = dependency_map[node_id]
            if self._loaded_nodes.get(node_id) is not None:
                # Deps are only used if the node has not been created.
                continue
            proto = self._proto.nodes[node_id]
            # `set` drops duplicate values before they are recorded.
            for dep in set(self._get_node_dependencies(proto).values()):
                deps.append(dep)
                if self._filtered_nodes is not None and dep not in self._filtered_nodes:
                    raise ValueError(
                        "Unable to partially load SavedModel since the specified filter "
                        "does not include all required objects for loading (e.g. "
                        "variables used in functions or deserialization dependencies). "
                        "Please include this path in the filter: "
                        f"{self._pretty_printer.node_names[dep]}")

            # Add optimizer slot variable to dependency map.
            # NOTE(review): the slot variable and original variable ids added
            # here are not checked against `self._filtered_nodes`; confirm this is
            # intended for partial loads.
            prev_slot = None
            for slot_variable_proto in proto.slot_variables:
                slot_variable_node_id = slot_variable_proto.slot_variable_node_id
                # The optimizer and original variable must be created before the slot
                # variable, since the slot variable is generated using the Optimizer's
                # add_slot API.
                slot_deps = dependency_map[slot_variable_node_id]
                slot_deps.append(node_id)
                slot_deps.append(slot_variable_proto.original_variable_node_id)

                if prev_slot is not None:
                    # Add previous slot to deps so that the optimizer slot variables are
                    # added in order. The ordering is needed because the slot name and
                    # variable are both added to ordered lists, which are exposed to the
                    # user via `Optimizer.get_slot_names()` and `Optimizer.weights`.
                    # TODO(kathywu): Maybe enforce some sort of deterministic ordering in
                    # `order_by_dependency` to avoid doing this?
                    slot_deps.append(prev_slot)
                prev_slot = slot_variable_node_id
        try:
            return list(trackable_utils.order_by_dependency(dependency_map))
        except trackable_utils.CyclicDependencyError as err:
            # This should not happen since there is already a validation for cycles
            # when saving, but raise an error just in case. The adjacent literals
            # below are concatenated, so each piece needs its own trailing space.
            raise ValueError(
                "Encountered a cycle in the deserialization dependencies "
                "in the SavedModel. This is extremely unexpected, please "
                "file a bug and make sure you are not manually modifying "
                "the SavedModel.") from err
コード例 #3
0
    def test_order_by_dependency(self):
        """Tests order_by_dependency correctness."""

        # Visual graph (vertical lines point down, so 1 depends on 2):
        #    1
        #  /   \
        # 2 --> 3 <-- 4
        #       |
        #       5
        # One possible order: [5, 3, 4, 2, 1]
        dependencies = {1: [2, 3], 2: [3], 3: [5], 4: [3], 5: []}

        sorted_arr = list(trackable_utils.order_by_dependency(dependencies))
        # The result must contain every node exactly once (no drops, no
        # duplicates) before the positional checks below are meaningful.
        self.assertCountEqual(sorted_arr, list(dependencies))
        indices = {x: sorted_arr.index(x) for x in range(1, 6)}
        # 5 is the only node with no dependencies and 3 is the only node that
        # depends solely on 5, so these positions are forced in any valid order.
        self.assertEqual(indices[5], 0)
        self.assertEqual(indices[3], 1)
        self.assertGreater(indices[1], indices[2])  # 2 must appear before 1
        # General contract: every dependency precedes the node that needs it.
        for node, deps in dependencies.items():
            for dep in deps:
                self.assertLess(indices[dep], indices[node])
コード例 #4
0
 def test_order_by_dependency_invalid_map(self):
     """Checks that a dependency which is not itself a key raises ValueError."""
     # Node 1 depends on node 2, but node 2 has no entry of its own.
     dangling_map = {1: [2]}
     expected_message = "Found values in the dependency map which are not keys"
     with self.assertRaisesRegex(ValueError, expected_message):
         trackable_utils.order_by_dependency(dangling_map)
コード例 #5
0
 def test_order_by_no_dependency(self):
     """Checks that nodes without any dependencies are all returned."""
     node_ids = range(15)
     empty_deps = {node_id: [] for node_id in node_ids}
     ordered = trackable_utils.order_by_dependency(empty_deps)
     self.assertEqual(set(ordered), set(node_ids))