Esempio n. 1
0
    def _graph_def_from_concrete_fn(self, cfs):
        """Freeze one concrete function and return its Grappler-optimized GraphDef.

        The function is frozen with ``_convert_variables_to_constants_v2``
        (control flow is not lowered; ``aggressive_inlining=True`` is added on
        TensorFlow >= 2.2.0). The frozen graph is serialized with shape
        attributes. Then the Grappler optimizers named by
        ``self._get_grappler_optimizers_list()`` are run over it.

        Args:
            cfs: Sequence of concrete functions; must contain exactly one.

        Returns:
            The GraphDef produced by ``_run_graph_optimizations``.

        Raises:
            NotImplementedError: If ``cfs`` does not hold exactly one function.
        """
        if len(cfs) != 1:
            raise NotImplementedError(
                "Only a single concrete function is supported.")

        freeze_kwargs = {"lower_control_flow": False}
        # NOTE(review): the version gate presumably exists because older TF
        # releases do not accept ``aggressive_inlining`` — confirm.
        if _get_version(_tf.__version__) >= _StrictVersion("2.2.0"):
            freeze_kwargs["aggressive_inlining"] = True
        frozen_fn = _convert_variables_to_constants_v2(cfs[0], **freeze_kwargs)

        frozen_graph = frozen_fn.graph
        serialized = frozen_graph.as_graph_def(add_shapes=True)

        # Resource-dtype inputs are left out of the feed list handed to
        # Grappler (presumably leftover variable handles after freezing).
        feed_tensors = [
            tensor for tensor in frozen_fn.inputs
            if tensor.dtype != _dtypes.resource
        ]

        # Run the configurable set of Grappler passes for this converter.
        optimizer_names = self._get_grappler_optimizers_list()
        grappler_config = _get_grappler_config(optimizer_names)
        return _run_graph_optimizations(
            serialized,
            feed_tensors,
            frozen_fn.outputs,
            config=grappler_config,
            graph=frozen_graph,
        )
Esempio n. 2
0
    def _graph_def_from_concrete_fn(cfs):
        """Freeze one concrete function and return its Grappler-optimized GraphDef.

        The function is frozen with ``_convert_variables_to_constants_v2``
        (control flow is not lowered). The frozen graph is serialized with
        shape attributes. Then Grappler's ``constfold`` and ``dependency``
        optimizers are run over it.

        Args:
            cfs: Sequence of concrete functions; must contain exactly one.

        Returns:
            The GraphDef produced by ``_run_graph_optimizations``.

        Raises:
            NotImplementedError: If ``cfs`` does not hold exactly one function.
        """
        if len(cfs) != 1:
            raise NotImplementedError(
                "Only a single concrete function is supported.")

        only_fn = cfs[0]
        frozen_fn = _convert_variables_to_constants_v2(
            only_fn, lower_control_flow=False)
        frozen_graph = frozen_fn.graph
        serialized = frozen_graph.as_graph_def(add_shapes=True)

        # Resource-dtype inputs are left out of the feed list handed to
        # Grappler (presumably leftover variable handles after freezing).
        feed_tensors = [
            tensor for tensor in frozen_fn.inputs
            if tensor.dtype != _dtypes.resource
        ]

        # Fixed pass list: constant folding plus the dependency optimizer.
        grappler_config = _get_grappler_config(["constfold", "dependency"])
        return _run_graph_optimizations(
            serialized,
            feed_tensors,
            frozen_fn.outputs,
            config=grappler_config,
            graph=frozen_graph,
        )