Example #1
    def test_create_variables_with_same_name(self):
        def f():
            v1 = variables.Variable(0, name='v')
            v2 = variables.Variable(1, name='v')
            return v1, v2

        f_wrapped = wrap_function.wrap_function(f, [])
        self.assertDictEqual(
            {
                'v:0': 0,
                'v_1:0': 1
            },  # assert that variable names are uniquified
            {
                v.name: v.numpy()
                for v in f_wrapped._variable_holder.variables.values()
            })

        # Uniquification should reset in separate calls to wrap_function.
        def f2():
            v1 = variables.Variable(3, name='v')
            v2 = variables.Variable(4, name='v')
            return v1, v2

        f_wrapped_2 = wrap_function.wrap_function(f2, [])
        self.assertDictEqual({
            'v:0': 3,
            'v_1:0': 4
        }, {
            v.name: v.numpy()
            for v in f_wrapped_2._variable_holder.variables.values()
        })
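For reference, here is a minimal standalone sketch of the same name-uniquification behavior through the public tf.compat.v1.wrap_function entry point. This is a sketch assuming TensorFlow 2.x; make_two_vars is a name chosen for illustration, not part of the example above.

import tensorflow as tf

def make_two_vars():
  # Both variables request the name 'v'; the wrapped graph uniquifies them.
  v1 = tf.Variable(0, name='v')
  v2 = tf.Variable(1, name='v')
  return v1, v2

wrapped = tf.compat.v1.wrap_function(make_two_vars, signature=[])
print(sorted(v.name for v in wrapped.graph.variables))  # ['v:0', 'v_1:0']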
  def test_to_proto(self):
    v1 = resource_variable_ops.ResourceVariable(2.)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    proto_accumulator = []
    wrapped = wrap_function.wrap_function(
        lambda: proto_accumulator.append(saver.to_proto()), signature=())
    self.assertEqual(1, len(proto_accumulator))
    proto = proto_accumulator[0]
    save = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_tensor_by_name(proto.save_tensor_name))
    restore = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_operation_by_name(proto.restore_op_name))
    save_path = save(constant_op.constant(prefix))
    v1.assign(1.)
    restore(constant_op.constant(save_path))
    self.assertEqual(2., self.evaluate(v1))

    v2 = resource_variable_ops.ResourceVariable(3.)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    second_saver.restore(save_path)
    self.assertEqual(2., self.evaluate(v2))
Example #3
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`."""
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    for node in meta_graph_def.graph_def.node:
      if node.op == "VariableV2":
        raise NotImplementedError(
            "Importing a SavedModel which contains RefVariables. This is not "
            "currently supported. Running tf.enable_resource_variables() "
            "before creating exported variables will fix this.")
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(meta_graph_def)
    if init_op is not None:
      # TODO(allenl): Deal with assets
      wrapped.prune(feeds=[],
                    fetches=[wrapped.graph.as_graph_element(init_op)])()
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)
    root = tracking.AutoCheckpointable()
    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    return root
Example #4
    def testImportedFunctionsRegistered(self):
        if test.is_built_with_gpu_support():
            self.skipTest(
                "Disabling this new test due to errors with cuda and rocm")

        with ops.Graph().as_default() as graph:
            x = array_ops.placeholder(dtypes.variant, shape=[], name='foo')
            ds = dataset_ops.from_variant(x,
                                          structure=(structure.TensorStructure(
                                              dtypes.int32, [])))
            y = ds.reduce(array_ops.zeros([], dtype=dtypes.int32),
                          lambda p, q: p + q)

        graph_def = graph.as_graph_def()

        def fn_to_wrap(a):
            returned_elements = graph_def_importer.import_graph_def(
                graph_def, input_map={x.name: a}, return_elements=[y.name])
            return returned_elements[0]

        wrapped_fn = wrap_function.wrap_function(
            fn_to_wrap, [tensor_spec.TensorSpec((), dtypes.variant)])
        ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
        v = dataset_ops.to_variant(ds)
        self.evaluate(wrapped_fn(v))
  def testNoArguments(self):

    def f():
      return constant_op.constant(1.)

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(1.0, f_wrapped())
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`."""
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(meta_graph_def)
    root = tracking.AutoTrackable()
    if init_op is not None:
      asset_feed_tensors = []
      asset_paths = []
      for tensor_name, value in loader_impl.get_asset_tensors(
          self._export_dir, meta_graph_def).items():
        asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
        asset_paths.append(tracking.TrackableAsset(value))
      init_fn = wrapped.prune(
          feeds=asset_feed_tensors,
          fetches=[wrapped.graph.as_graph_element(init_op)])
      initializer = _Initializer(init_fn, asset_paths)
      initializer._initialize()  # pylint: disable=protected-access
      root.initializer = initializer
      root.asset_paths = asset_paths
    else:
      root.asset_paths = []
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    return root
    def testGradientsOfPrune(self):

        v1 = variables.Variable(2.)
        v2_holder = []

        def f(z):
            v2 = variables.Variable(3.)
            v2_holder.append(v2)
            return array_ops.identity(v1 * v2 * z, 'fetch')

        f_wrapped = wrap_function.wrap_function(
            f, [tensor_spec.TensorSpec((), dtype=dtypes.float32)])

        x = constant_op.constant(1.)
        with backprop.GradientTape() as tape:
            tape.watch(x)
            out = f_wrapped(x)
        grads = tape.gradient(out, [x, v1, v2_holder[0]])

        self.assertAllEqual(6.0, out)
        self.assertAllEqual([6.0, 3.0, 2.0], grads)

        pruned = f_wrapped.prune(
            feeds=f_wrapped.inputs,
            fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))

        x = constant_op.constant(1.)
        with backprop.GradientTape() as tape:
            tape.watch(x)
            out = pruned(x)
        grads = tape.gradient(out, [x, v1, v2_holder[0]])

        self.assertAllEqual(6.0, out)
        self.assertAllEqual([6.0, 3.0, 2.0], grads)
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`."""
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(meta_graph_def)
    root = tracking.AutoTrackable()
    if init_op is not None:
      asset_feed_tensors = []
      asset_paths = []
      for tensor_name, value in loader_impl.get_asset_tensors(
          self._export_dir, meta_graph_def).items():
        asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
        asset_paths.append(tracking.TrackableAsset(value))
      init_fn = wrapped.prune(
          feeds=asset_feed_tensors,
          fetches=[wrapped.graph.as_graph_element(init_op)])
      initializer = _Initializer(init_fn, asset_paths)
      initializer.initialize()
      root.initializer = initializer
      root.asset_paths = asset_paths
    else:
      root.asset_paths = []
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    return root
  def wrap_and_execute(self):
    if context.executing_eagerly():
      wrapped = wrap_function.wrap_function(graph_function, [self])
      # use the wrapped graph function
      wrapped()
    else:
      # use the original function
      graph_function(self)
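The snippet above is a fragment: graph_function is defined in its enclosing scope. Below is a self-contained sketch of the same dispatch pattern, with a hypothetical graph_function standing in for the real graph-building callable.

from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op

def graph_function(obj):
  # Hypothetical stand-in: build and return graph ops that use `obj`.
  return constant_op.constant(1.)

class Runner(object):

  def wrap_and_execute(self):
    if context.executing_eagerly():
      # Trace graph_function into a graph, then run the wrapped version.
      wrapped = wrap_function.wrap_function(graph_function, [self])
      wrapped()
    else:
      # Already building a graph; call the original function directly.
      graph_function(self)

Runner().wrap_and_execute()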
    def testCollectionsIsolation(self):

        v1 = variables.Variable(2.)
        v2_holder = []

        def f():
            v2 = variables.Variable(3.)
            v2_holder.append(v2)
            ops.add_to_collection(ops.GraphKeys.LOSSES,
                                  v2 * constant_op.constant(3.))
            return array_ops.identity(v1 * v2 * constant_op.constant(1.),
                                      'fetch')

        f_wrapped = wrap_function.wrap_function(f, [])
        self.assertAllEqual(6.0, f_wrapped())
        self.assertEqual(
            len(f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)), 1)
        f_var_collection = f_wrapped.graph.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        self.assertEqual(len(f_var_collection), 1)
        self.assertIs(f_var_collection[0], v2_holder[0])

        v3_holder = []

        def g():
            v3 = variables.Variable(4.)
            v3_holder.append(v3)
            ops.add_to_collection(ops.GraphKeys.LOSSES,
                                  v3 * constant_op.constant(3.))
            return array_ops.identity(v1 * v3 * constant_op.constant(1.),
                                      'fetch')

        g_wrapped = wrap_function.wrap_function(g, [])
        self.assertAllEqual(8.0, g_wrapped())
        self.assertEqual(
            len(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)), 1)
        g_var_collection = g_wrapped.graph.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        self.assertEqual(len(g_var_collection), 1)
        self.assertIs(g_var_collection[0], v3_holder[0])

        # Both have only one value, and their values aren't equal. So no sharing.
        self.assertNotEqual(
            g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES),
            f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES))
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`."""
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.TrackableAsset(value))
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root
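For orientation, this load method is the v1-SavedModel compatibility path that tf.saved_model.load falls back to in TF 2.x when given a model exported from TF 1.x. A hedged usage sketch follows; the export path and tag are illustrative.

import tensorflow as tf

loaded = tf.saved_model.load('/tmp/exported_v1_model', tags=['serve'])
print(list(loaded.signatures))    # e.g. ['serving_default']
print(loaded.tensorflow_version)  # version recorded in the MetaGraphDef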
Example #15
    def test_updates_in_wrap_function(self):
        def my_func():
            layer = normalization.BatchNormalization()
            x = array_ops.ones((10, 1))
            y = layer(x, training=True)
            # Updates should be tracked in a `wrap_function`.
            self.assertLen(layer.updates, 2)
            return y

        wrapped_fn = wrap_function.wrap_function(my_func, [])
        wrapped_fn()
    def testPruneOperations(self):

        v = variables.Variable(0)

        def f():
            v.assign_add(1, name='increment', read_value=False)

        f_wrapped = wrap_function.wrap_function(f, [])
        pruned = f_wrapped.prune(
            feeds=(),
            fetches=(f_wrapped.graph.get_operation_by_name('increment'), ))
        self.assertEqual((None, ), pruned())
        self.assertEqual(1, self.evaluate(v))

        del f, f_wrapped

        def f1():
            v.assign_add(array_ops.placeholder(shape=[],
                                               dtype=dtypes.int32,
                                               name='step'),
                         name='increment',
                         read_value=False)
            return constant_op.constant(1, name='other')

        f_wrapped = wrap_function.wrap_function(f1, [])
        increments = f_wrapped.prune(
            feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
            fetches=(f_wrapped.graph.get_operation_by_name('increment'),
                     f_wrapped.graph.get_tensor_by_name('other:0')))
        first_output, second_output = increments(constant_op.constant(2))
        self.assertEqual(['Placeholder:0', 'Placeholder_1:0'],
                         [t.name for t in increments.inputs])
        self.assertIs(None, first_output)
        self.assertEqual(1, second_output.numpy())
        self.assertEqual(3, v.numpy())
        does_not_increment = f_wrapped.prune(
            feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
            fetches=f_wrapped.graph.get_tensor_by_name('other:0'))
        self.assertEqual(1,
                         does_not_increment(constant_op.constant(3)).numpy())
        self.assertEqual(3, v.numpy())
Example #17
    def testVariableLifting(self):
        save_prefix = os.path.join(self.get_temp_dir(), 'meta_graph_test')

        export_graph = ops.Graph()
        with export_graph.as_default():
            v = variables.Variable(1.)
            array_ops.identity(v + 1., name='output')
            saver = saver_lib.Saver([v])
            with self.test_session() as session:
                session.run(v.initializer)
                saver.save(session, save_prefix)

        def importer():
            saver_lib.import_meta_graph(save_prefix + '.meta')
            return ops.get_default_graph().as_graph_element('output:0')

        wrapped = wrap_function.wrap_function(importer, [])
        lifted_variables = list(wrapped.graph.variables)
        self.assertLen(lifted_variables, 1)
        initializer = wrapped.prune([],
                                    wrapped.graph.as_graph_element(
                                        v.initializer.name))
        self.assertEqual(lifted_variables, list(initializer.graph.variables))
        self.assertEqual(initializer.graph.external_captures,
                         wrapped.graph.external_captures)

        @def_function.function
        def wraps_initializer():
            initializer()

        wraps_initializer()
        self.assertEqual(1., lifted_variables[0].numpy())
        wrapped_initializer_graphdef = (
            wraps_initializer.get_concrete_function().graph.as_graph_def())
        self._assert_single_captured_variable_argument(
            wrapped_initializer_graphdef)

        @def_function.function
        def wraps_wrapped():
            return wrapped()

        # Verify that the original graph also has the correct signature.
        wrapped_wrapped_graphdef = (
            wraps_wrapped.get_concrete_function().graph.as_graph_def())
        self._assert_single_captured_variable_argument(
            wrapped_wrapped_graphdef)
        # Now check that the graph runs wrapped, from eager, and when pruned.
        self.assertAllEqual(wraps_wrapped().numpy(),
                            lifted_variables[0].numpy() + 1.)
        self.assertAllEqual(wrapped().numpy(),
                            lifted_variables[0].numpy() + 1.)
        pruned = wrapped.prune([], wrapped.graph.as_graph_element('output:0'))
        self.assertAllEqual(wrapped().numpy(), pruned().numpy())
Example #18
def my_function_from_graph_def(graph_def, inputs, outputs, ref_captures):
    def _imports_graph_def():
        importer.import_graph_def(graph_def, name="")

    wrapped_import = wrap_function.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    wrapped_import.graph.reset_captures([
        (tensor, import_graph.get_tensor_by_name(placeholder.name))
        for tensor, placeholder in ref_captures
    ])
    return wrapped_import.prune(
        nest.map_structure(import_graph.as_graph_element, inputs),
        nest.map_structure(import_graph.as_graph_element, outputs))
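A hypothetical usage sketch for the helper above: rebuilding a callable from a ConcreteFunction's GraphDef. It assumes the helper's own imports (importer, nest, wrap_function) are in scope; double is an illustrative function with no captured tensors, so ref_captures is empty.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def double(x):
  return x * 2.

concrete = double.get_concrete_function()
rebuilt = my_function_from_graph_def(
    concrete.graph.as_graph_def(),
    inputs=[t.name for t in concrete.inputs],
    outputs=[t.name for t in concrete.outputs],
    ref_captures=[])
result, = rebuilt(tf.constant(3.))  # prune returns a list of fetches
print(result.numpy())  # 6.0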
    def test_updates_in_wrap_function(self):
        with context.eager_mode():
            layer = keras.layers.BatchNormalization()

            def my_func():
                x = array_ops.ones((10, 1))
                return layer(x, training=True)

            wrapped_fn = wrap_function.wrap_function(my_func, [])
            wrapped_fn()

            # Updates should be tracked in a `wrap_function`.
            self.assertLen(layer.updates, 2)
  def testPruneOperations(self):

    v = variables.Variable(0)

    def f():
      v.assign_add(1, name='increment', read_value=False)

    f_wrapped = wrap_function.wrap_function(f, [])
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),))
    self.assertEqual((None,), pruned())
    self.assertEqual(1, self.evaluate(v))

    del f, f_wrapped

    def f1():
      v.assign_add(
          array_ops.placeholder(shape=[], dtype=dtypes.int32, name='step'),
          name='increment', read_value=False)
      return constant_op.constant(1, name='other')

    f_wrapped = wrap_function.wrap_function(f1, [])
    increments = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),
                 f_wrapped.graph.get_tensor_by_name('other:0')))
    first_output, second_output = increments(constant_op.constant(2))
    self.assertEqual(['step:0', 'increment/resource:0'],
                     [t.name for t in increments.inputs])
    self.assertIs(None, first_output)
    self.assertEqual(1, second_output.numpy())
    self.assertEqual(3, v.numpy())
    does_not_increment = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=f_wrapped.graph.get_tensor_by_name('other:0'))
    self.assertEqual(1, does_not_increment(constant_op.constant(3)).numpy())
    self.assertEqual(3, v.numpy())
    def testCaptures(self):

        v1 = variables.Variable(2.)

        def f():
            v2 = variables.Variable(3.)
            return array_ops.identity(v1 * v2 * constant_op.constant(1.),
                                      'fetch')

        f_wrapped = wrap_function.wrap_function(f, [])
        self.assertAllEqual(6.0, f_wrapped())
        pruned = f_wrapped.prune(
            feeds=(), fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
        self.assertAllEqual(6.0, pruned())
    def testDocString(self):
        def f(x, do_add):
            v = variables.Variable(5.0)
            if do_add:
                op = v.assign_add(x)
            else:
                op = v.assign_sub(x)
            with ops.control_dependencies([op]):
                return v.read_value()

        f_add = wrap_function.wrap_function(
            f, [tensor_spec.TensorSpec((), dtypes.float32), True])

        self.assertAllEqual(f_add(1.0), 6.0)
        self.assertAllEqual(f_add(1.0), 7.0)

        # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
        # of variables, and possibly different non-template arguments.
        f_sub = wrap_function.wrap_function(
            f, [tensor_spec.TensorSpec((), dtypes.float32), False])

        self.assertAllEqual(f_sub(1.0), 4.0)
        self.assertAllEqual(f_sub(1.0), 3.0)
Example #29
    def testPruneStatefulOpsFromWrappedFunc(self):

        v0 = variables.Variable(0)
        v1 = variables.Variable(0)

        # When we wrap a function, we expect it to be executed with `tf.Graph`
        # rules: it's allowed to prune all ops that are not in the transitive
        # fanin of the fetches.
        def f(x):
            v0.assign_add(1, name='increment_v0')
            v1.assign_add(1, name='increment_v1')
            return x

        f_wrapped = wrap_function.wrap_function(f, [1])

        self.assertEqual(1, f_wrapped().numpy())
        self.assertEqual(0, v0.numpy())
        self.assertEqual(0, v1.numpy())

        f_wrapped_with_name = wrap_function.wrap_function(f, [2], name='func')

        self.assertEqual(2, f_wrapped_with_name().numpy())
        self.assertEqual(0, v0.numpy())
        self.assertEqual(0, v1.numpy())
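For contrast, a minimal sketch (assuming TF 2.x) of the same function under tf.function, which adds automatic control dependencies so the stateful assign_add is executed rather than pruned.

import tensorflow as tf

v0 = tf.Variable(0)

@tf.function
def f(x):
  v0.assign_add(1, name='increment_v0')  # kept by automatic control deps
  return x

f(tf.constant(1))
print(v0.numpy())  # 1: the side effect ran, unlike under wrap_function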
  def testPrune(self):

    x_in = []
    x_out = []

    def f(x, y):
      x_in.append(x)
      xx = x * x
      x_out.append(xx)
      return xx, 2 * y*y

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2)

    f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
    self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])
Example #32
    def test_operation_returned(self):

        v = variables.Variable(0)

        def f():
            v.assign(1, read_value=False, name='assign_to_v')

        f_wrapped = wrap_function.wrap_function(f, [])
        operation_to_fetch = f_wrapped.graph.get_operation_by_name(
            'assign_to_v')
        f_pruned = f_wrapped.prune([], operation_to_fetch)
        self.assertEqual(
            ['assign_to_v'],
            [operation.name for operation in f_pruned.graph.control_outputs])
        self.assertEqual(0, v.numpy())
        f_pruned()
        self.assertEqual(1, v.numpy())
Example #34
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`."""
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(meta_graph_def)
    if init_op is not None:
      # TODO(allenl): Deal with assets
      wrapped.prune(feeds=[],
                    fetches=[wrapped.graph.as_graph_element(init_op)])()
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)
    root = tracking.AutoCheckpointable()
    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    return root
  def testPruneCaptures(self):

    v1 = variables.Variable(2.)

    def f():
      v2 = variables.Variable(3.)
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())

    # Test pruning directly on the inputs
    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

    # Test pruning with no inputs
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())
Example #36
    def testImportedFunctionsRegistered(self):
        if test_util.is_gpu_available():
            self.skipTest('not a GPU test')
        with ops.Graph().as_default() as graph:
            x = array_ops.placeholder(dtypes.variant, shape=[], name='foo')
            ds = dataset_ops.from_variant(x,
                                          structure=(tensor_spec.TensorSpec(
                                              [], dtypes.int32)))
            y = ds.reduce(array_ops.zeros([], dtype=dtypes.int32),
                          lambda p, q: p + q)

        graph_def = graph.as_graph_def()

        def fn_to_wrap(a):
            returned_elements = graph_def_importer.import_graph_def(
                graph_def, input_map={x.name: a}, return_elements=[y.name])
            return returned_elements[0]

        wrapped_fn = wrap_function.wrap_function(
            fn_to_wrap, [tensor_spec.TensorSpec((), dtypes.variant)])
        ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
        v = dataset_ops.to_variant(ds)
        self.evaluate(wrapped_fn(v))
    def testWrapFuncDatasetDevice(self, device_type, dataset_reduce_fn):

        devices = config.list_logical_devices(device_type=device_type)
        if not devices:
            self.skipTest(
                'Skip when {} is not detected by TF'.format(device_type))

        @def_function.function
        def comp():
            return dataset_reduce_fn(dataset_ops.Dataset.range(10))

        graph = comp.get_concrete_function().graph

        def function_to_wrap():
            with ops.device(devices[0].name):
                return graph_def_importer.import_graph_def(
                    graph.as_graph_def())

        with ops.device(devices[0].name):
            wrapped_noarg_fn = wrap_function.wrap_function(function_to_wrap,
                                                           signature=[])

        wrapped_noarg_fn()
Example #38
    def testPruneRagged(self):

        x_in = []
        x_out = []

        def f(x, y):
            x_in.append(x)
            xx = x * x
            x_out.append(xx)
            return xx, y * y

        x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32)
        y_spec = tensor_spec.TensorSpec((), dtypes.float32)

        f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec])

        f_pruned = f_wrapped.prune(x_in[0], x_out[0])
        rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
        expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])

        # Note: when we call f_pruned, we must pass the RaggedTensor in using
        # its components, since that's the current convention for how concrete
        # functions handle structured inputs.
        self.assertAllEqual(f_pruned(rt.values, rt.row_splits), expected)
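The component-feeding convention noted in the comment above can be made explicit with tf.nest. A sketch, assuming the f_pruned and rt from this example; flatten with expand_composites yields the components in the same (values, row_splits) order the test passes by hand.

import tensorflow as tf

# Flatten the RaggedTensor into its component tensors before feeding them
# to the pruned concrete function.
components = tf.nest.flatten(rt, expand_composites=True)
result = f_pruned(*components)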
Example #39
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`."""
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        load_shared_name_suffix = "_load_{}".format(ops.uid())
        functions = function_deserialization.load_function_def_library(
            meta_graph_def.graph_def.library,
            load_shared_name_suffix=load_shared_name_suffix)
        # Replace existing functions in the MetaGraphDef with renamed functions so
        # we don't have duplicates or name collisions.
        meta_graph_def.graph_def.library.Clear()
        for function in functions.values():
            meta_graph_def.graph_def.library.function.add().CopyFrom(
                function.function_def)
        # We've renamed functions and shared names. We need the same operation on
        # the GraphDef itself for consistency.
        for node_def in meta_graph_def.graph_def.node:
            function_deserialization.fix_node_def(
                node_def,
                functions,
                load_shared_name_suffix,
                debug_name="MetaGraph import")

        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.Asset(value))
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root