Example #1
 def testNestedLists(self):
     a = tracking.Checkpointable()
     a.l = []
     b = tracking.Checkpointable()
     a.l.append([b])
     c = tracking.Checkpointable()
     a.l[0].append(c)
     a_deps = util.list_objects(a)
     self.assertIn(b, a_deps)
     self.assertIn(c, a_deps)
     a.l[0].append(1)
     d = tracking.Checkpointable()
     a.l[0].append(d)
     a_deps = util.list_objects(a)
     self.assertIn(d, a_deps)
     self.assertIn(b, a_deps)
     self.assertIn(c, a_deps)
     self.assertNotIn(1, a_deps)
     e = tracking.Checkpointable()
     f = tracking.Checkpointable()
     a.l1 = [[], [e]]
     a.l1[0].append(f)
     a_deps = util.list_objects(a)
     self.assertIn(e, a_deps)
     self.assertIn(f, a_deps)
     checkpoint = util.Checkpoint(a=a)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     a.l[0].append(data_structures.NoDependency([]))
     a.l[0][-1].append(5)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     # Dirtying the inner list means the root object is unsaveable.
     a.l[0][1] = 2
     with self.assertRaisesRegexp(ValueError,
                                  "A list element was replaced"):
         checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
Example #2
 def testNestedLists(self):
   a = tracking.Checkpointable()
   a.l = []
   b = tracking.Checkpointable()
   a.l.append([b])
   c = tracking.Checkpointable()
   a.l[0].append(c)
   a_deps = util.list_objects(a)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   a.l[0].append(1)
   d = tracking.Checkpointable()
   a.l[0].append(d)
   a_deps = util.list_objects(a)
   self.assertIn(d, a_deps)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   self.assertNotIn(1, a_deps)
   e = tracking.Checkpointable()
   f = tracking.Checkpointable()
   a.l1 = [[], [e]]
   a.l1[0].append(f)
   a_deps = util.list_objects(a)
   self.assertIn(e, a_deps)
   self.assertIn(f, a_deps)
   checkpoint = util.Checkpoint(a=a)
   checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   a.l[0].append(data_structures.NoDependency([]))
   a.l[0][-1].append(5)
   checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   # Dirtying the inner list means the root object is unsaveable.
   a.l[0][1] = 2
   with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
Example #3
 def get_non_optimizer_objects(m, g):
   """Gather set of model and optimizer checkpointable objects."""
   # Set default graph because optimizer.variables() returns optimizer
   # variables defined in the default graph.
   with g.as_default():
     all_objects = set(checkpointable_utils.list_objects(m))
     optimizer_and_variables = set()
     for obj in all_objects:
       if isinstance(obj, optimizers.TFOptimizer):
         optimizer_and_variables.update(checkpointable_utils.list_objects(obj))
         optimizer_and_variables.update(set(obj.optimizer.variables()))
     return all_objects - optimizer_and_variables
Example #4
 def testAddVariableOverwrite(self):
   root = base.Checkpointable()
   a = root._add_variable_with_custom_getter(
       name="v", shape=[], getter=variable_scope.get_variable)
   self.assertEqual([root, a], util.list_objects(root))
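   # In a fresh graph, overwrite=True lets a new variable replace the existing "v" dependency.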
   with ops.Graph().as_default():
     b = root._add_variable_with_custom_getter(
         name="v", shape=[], overwrite=True,
         getter=variable_scope.get_variable)
     self.assertEqual([root, b], util.list_objects(root))
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(
         ValueError, "already declared as a dependency"):
       root._add_variable_with_custom_getter(
           name="v", shape=[], overwrite=False,
           getter=variable_scope.get_variable)
Example #5
 def testAddVariableOverwrite(self):
   root = base.CheckpointableBase()
   a = root._add_variable_with_custom_getter(
       name="v", shape=[], getter=variable_scope.get_variable)
   self.assertEqual([root, a], util.list_objects(root))
   with ops.Graph().as_default():
     b = root._add_variable_with_custom_getter(
         name="v", shape=[], overwrite=True,
         getter=variable_scope.get_variable)
     self.assertEqual([root, b], util.list_objects(root))
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(
         ValueError, "already declared as a dependency"):
       root._add_variable_with_custom_getter(
           name="v", shape=[], overwrite=False,
           getter=variable_scope.get_variable)
Example #6
 def testDictWrapperNoDependency(self):
   a = tracking.Checkpointable()
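   # NoDependency keeps the dict (and anything later added to it) out of a's checkpoint dependencies.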
   a.d = data_structures.NoDependency({})
   a.d[1] = [3]
   self.assertEqual([a], util.list_objects(a))
   model = training.Model()
   model.sub = a
   save_path = os.path.join(self.get_temp_dir(), "ckpt")
   model.save_weights(save_path)
   model.load_weights(save_path)
Example #7
 def testNonStringKeyNotCheckpointableValue(self):
   a = tracking.Checkpointable()
   a.d = {}
   a.d["a"] = [3]
   a.d[1] = data_structures.NoDependency([3])
   self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
   model = training.Model()
   model.sub = a
   save_path = os.path.join(self.get_temp_dir(), "ckpt")
   model.save_weights(save_path)
   model.load_weights(save_path)
Example #8
    def test_checkpointable_save_restore(self):
        def _templated():
            v = variable_scope.get_variable(
                "v",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            v2 = variable_scope.get_variable(
                "v2",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            manual = _ManualScope()
            return v, v + 1., v2, manual, manual()

        save_template = template.make_template("s1", _templated)
        v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
        six.assertCountEqual(
            self,
            [v1_save, v2_save, manual_scope, manual_scope_v, save_template],
            checkpointable_utils.list_objects(save_template))
        manual_dep, = manual_scope._checkpoint_dependencies
        self.assertEqual("in_manual_scope", manual_dep.name)
        self.assertIs(manual_scope_v, manual_dep.ref)
        optimizer = adam.AdamOptimizer(0.0)
        save_root = checkpointable_utils.Checkpoint(my_template=save_template,
                                                    optimizer=optimizer)
        optimizer.minimize(v1_save.read_value)
        self.evaluate([v.initializer for v in save_template.variables])
        self.evaluate([v.initializer for v in optimizer.variables()])
        self.evaluate(v1_save.assign([12.]))
        self.evaluate(v2_save.assign([14.]))
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        save_path = save_root.save(checkpoint_prefix)

        load_template = template.make_template("s2", _templated)
        load_optimizer = adam.AdamOptimizer(0.0)
        load_root = checkpointable_utils.Checkpoint(my_template=load_template,
                                                    optimizer=load_optimizer)
        status = load_root.restore(save_path)
        var, var_plus_one, var2, _, _ = load_template()
        load_optimizer.minimize(var.read_value)
        self.assertEqual(3, len(load_template._checkpoint_dependencies))
        self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
        self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
        self.assertEqual("ManualScope",
                         load_template._checkpoint_dependencies[2].name)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.], self.evaluate(var))
        self.assertAllEqual([13.], self.evaluate(var_plus_one))
        self.assertAllEqual([14.], self.evaluate(var2))
Example #9
 def testListBasic(self):
     a = tracking.Checkpointable()
     b = tracking.Checkpointable()
     a.l = [b]
     c = tracking.Checkpointable()
     a.l.append(c)
     a_deps = util.list_objects(a)
     self.assertIn(b, a_deps)
     self.assertIn(c, a_deps)
     direct_a_dep, = a._checkpoint_dependencies
     self.assertEqual("l", direct_a_dep.name)
     self.assertIn(b, direct_a_dep.ref)
     self.assertIn(c, direct_a_dep.ref)
Example #10
 def testShallowCopyCheckpointable(self):
   original = tracking.Checkpointable()
   original_sub = tracking.Checkpointable()
   original.a = [[1.]]
   original.b = {"a": original_sub}
   shallow_copied = copy.copy(original)
   self.assertIs(original_sub, shallow_copied.b["a"])
   self.assertIsNot(original, shallow_copied)
   self.assertEqual([[1.]], shallow_copied.a)
   shallow_deps = util.list_objects(shallow_copied)
   self.assertIn(shallow_copied.a, shallow_deps)
   self.assertIn(shallow_copied.b, shallow_deps)
   self.assertIn(shallow_copied.b["a"], shallow_deps)
Example #11
 def testListBasic(self):
   a = tracking.Checkpointable()
   b = tracking.Checkpointable()
   a.l = [b]
   c = tracking.Checkpointable()
   a.l.append(c)
   a_deps = util.list_objects(a)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
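   # The list attribute itself is a's only direct dependency, named after the attribute.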
   direct_a_dep, = a._checkpoint_dependencies
   self.assertEqual("l", direct_a_dep.name)
   self.assertIn(b, direct_a_dep.ref)
   self.assertIn(c, direct_a_dep.ref)
Example #12
  def test_checkpointable_save_restore(self):

    def _templated():
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      manual = _ManualScope()
      return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    six.assertCountEqual(
        self,
        [v1_save, v2_save, manual_scope, manual_scope_v, save_template],
        checkpointable_utils.list_objects(save_template))
    manual_dep, = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)
    optimizer = adam.AdamOptimizer(0.0)
    save_root = checkpointable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = checkpointable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
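    # Variables don't exist yet, so restoration is deferred until load_template() creates them.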
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    self.assertEqual(3, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual("ManualScope",
                     load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
Example #13
 def testDeepCopyCheckpointable(self):
   original = tracking.Checkpointable()
   original_sub = tracking.Checkpointable()
   original.a = [[1.]]
   original.b = {"a": original_sub}
   deep_copied = copy.deepcopy(original)
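   # deepcopy clones the tracked sub-object, so the copy's dependencies no longer include original_sub.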
   self.assertIsNot(original, deep_copied)
   self.assertIsNot(original_sub, deep_copied.b["a"])
   self.assertEqual([[1.]], deep_copied.a)
   self.assertIsInstance(deep_copied.b["a"], tracking.Checkpointable)
   deps = util.list_objects(deep_copied)
   self.assertIn(deep_copied.a, deps)
   self.assertIn(deep_copied.b, deps)
   self.assertIn(deep_copied.b["a"], deps)
   self.assertNotIn(original_sub, deps)
Example #14
    def test_checkpointable_dependencies(self):
        rnn = keras.layers.SimpleRNN
        x = np.random.random((2, 2, 2))
        y = np.random.random((2, 2))
        model = keras.models.Sequential()
        model.add(rnn(2))
        model.compile(optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001),
                      loss='mse')
        model.fit(x, y, epochs=1, batch_size=1)

        # check whether the model variables are present in the
        # checkpointable list of objects
        checkpointed_objects = set(checkpointable_util.list_objects(model))
        for v in model.variables:
            self.assertIn(v, checkpointed_objects)
Example #15
  def test_checkpointable_dependencies(self):
    rnn = keras.layers.SimpleRNN
    with self.test_session():
      x = np.random.random((2, 2, 2))
      y = np.random.random((2, 2))
      model = keras.models.Sequential()
      model.add(rnn(2))
      model.compile(optimizer='rmsprop', loss='mse')
      model.fit(x, y, epochs=1, batch_size=1)

      # check whether the model variables are present in the
      # checkpointable list of objects
      checkpointed_objects = set(checkpointable_util.list_objects(model))
      for v in model.variables:
        self.assertIn(v, checkpointed_objects)
Example #16
    def test_timedistributed_dense(self):
        model = keras.models.Sequential()
        model.add(
            keras.layers.TimeDistributed(keras.layers.Dense(2),
                                         input_shape=(3, 4)))
        model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
        model.fit(np.random.random((10, 3, 4)),
                  np.random.random((10, 3, 2)),
                  epochs=1,
                  batch_size=10)

        # test config
        model.get_config()

        checkpointed_objects = set(checkpointable_util.list_objects(model))
        for v in model.variables:
            self.assertIn(v, checkpointed_objects)
Example #17
  def test_timedistributed_dense(self):
    model = keras.models.Sequential()
    model.add(
        keras.layers.TimeDistributed(
            keras.layers.Dense(2), input_shape=(3, 4)))
    model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
    model.fit(
        np.random.random((10, 3, 4)),
        np.random.random((10, 3, 2)),
        epochs=1,
        batch_size=10)

    # test config
    model.get_config()

    checkpointed_objects = set(checkpointable_util.list_objects(model))
    for v in model.variables:
      self.assertIn(v, checkpointed_objects)
Example #18
    def testNoDependency(self):
        root = tracking.Checkpointable()
        hasdep = tracking.Checkpointable()
        root.hasdep = hasdep
        nodep = tracking.Checkpointable()
        root.nodep = data_structures.NoDependency(nodep)
        self.assertEqual(1, len(root._checkpoint_dependencies))
        self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
        self.assertIs(root.hasdep, hasdep)
        self.assertIs(root.nodep, nodep)

        class NoDependencyModel(training.Model):
            @base.no_automatic_dependency_tracking
            def __init__(self):
                super(NoDependencyModel, self).__init__()
                self.a = []
                self.b = tracking.Checkpointable()

        nodeps = NoDependencyModel()
        self.assertEqual([nodeps], util.list_objects(nodeps))
Example #19
 def testNonAppendNotCheckpointable(self):
   # Non-append mutations (deleting or overwriting values) are OK when the
   # values aren't tracked.
   a = tracking.Checkpointable()
   a.d = {}
   a.d["a"] = [3]
   a.d[1] = 3
   a.d[1] = 2
   self.assertEqual(2, a.d[1])
   del a.d[1]
   a.d[2] = data_structures.NoDependency(tracking.Checkpointable())
   second = tracking.Checkpointable()
   a.d[2] = data_structures.NoDependency(second)
   self.assertIs(second, a.d[2])
   self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
   model = training.Model()
   model.sub = a
   save_path = os.path.join(self.get_temp_dir(), "ckpt")
   model.save_weights(save_path)
   model.load_weights(save_path)
Example #20
  def testNoDependency(self):
    root = tracking.Checkpointable()
    hasdep = tracking.Checkpointable()
    root.hasdep = hasdep
    nodep = tracking.Checkpointable()
    root.nodep = data_structures.NoDependency(nodep)
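    # The attribute still holds nodep, but no checkpoint dependency is registered for it.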
    self.assertEqual(1, len(root._checkpoint_dependencies))
    self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
    self.assertIs(root.hasdep, hasdep)
    self.assertIs(root.nodep, nodep)

    class NoDependencyModel(training.Model):

      @base.no_automatic_dependency_tracking
      def __init__(self):
        super(NoDependencyModel, self).__init__()
        self.a = []
        self.b = tracking.Checkpointable()

    nodeps = NoDependencyModel()
    self.assertEqual([nodeps], util.list_objects(nodeps))
Example #21
 def testDictionariesBasic(self):
   a = training.Model()
   b = training.Model()
   a.attribute = {"b": b}
   c = training.Model()
   a.attribute["c"] = []
   a.attribute["c"].append(c)
   a_deps = util.list_objects(a)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   self.assertIs(b, a.attribute["b"])
   six.assertCountEqual(
       self,
       ["b", "c"],
       [dep.name for dep in a.attribute._checkpoint_dependencies])
   self.assertEqual([b, c], a.layers)
   self.assertEqual([b, c], a.attribute.layers)
   self.assertEqual([c], a.attribute["c"].layers)
   checkpoint = util.Checkpoint(a=a)
   save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   checkpoint.restore(save_path).assert_consumed()
Example #22
  def testDictDeepCopy(self):
    root = tracking.Checkpointable()
    orig_dict = {"a": [1.]}
    root.a = orig_dict
    copied = copy.deepcopy(root.a)
    self.assertAllEqual([1.], copied["a"])
    self.assertIsNot(root.a, copied)
    self.assertIsNot(root.a["a"], copied["a"])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    orig_dict["b"] = []
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))
Example #23
  def testListDeepCopy(self):
    root = tracking.Checkpointable()
    orig_list = [[1.]]
    root.a = orig_list
    copied = copy.deepcopy(root.a)
    self.assertAllEqual([[1.]], copied)
    self.assertIsNot(root.a, copied)
    self.assertIsNot(root.a[0], copied[0])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    orig_list.append(1.)
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))
Example #24
def _make_graph_def(root, signature_functions, object_saver):
    """Generates and exports call ops for `signature_functions`."""
    signatures = {}
    # List objects from the eager context to make sure Optimizers give us the
    # right Graph-dependent variables.
    accessible_objects = util.list_objects(root)
    exported_graph = ops.Graph()
    with exported_graph.as_default():
        object_map, resource_map = _map_resources(accessible_objects)
    # Saving an object-based checkpoint again gathers variables. We need to do the
    # gathering from the eager context so Optimizers save the right set of
    # variables, but want any operations associated with the save/restore to be in
    # the exported graph (thus the `to_graph` argument).
    saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)
    with exported_graph.as_default():
        signatures = _generate_signatures(signature_functions, resource_map)
        saver_def = saver.to_proto()
    graph_def = exported_graph.as_graph_def(add_shapes=True)
    # Clean reference cycles so repeated export()s don't make work for the garbage
    # collector.
    ops.dismantle_graph(exported_graph)
    return graph_def, signatures, saver_def
Example #25
def _make_graph_def(root, signature_functions, object_saver):
  """Generates and exports call ops for `signature_functions`."""
  signatures = {}
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = util.list_objects(root)
  exported_graph = ops.Graph()
  with exported_graph.as_default():
    object_map, resource_map = _map_resources(accessible_objects)
  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)
  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    saver_def = saver.to_proto()
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector.
  ops.dismantle_graph(exported_graph)
  return graph_def, signatures, saver_def
Example #26
    def test_bidirectional(self):
        rnn = keras.layers.SimpleRNN
        samples = 2
        dim = 2
        timesteps = 2
        output_dim = 2
        with self.cached_session():
            for mode in ['sum', 'concat', 'ave', 'mul']:
                x = np.random.random((samples, timesteps, dim))
                target_dim = 2 * output_dim if mode == 'concat' else output_dim
                y = np.random.random((samples, target_dim))

                # test with Sequential model
                model = keras.models.Sequential()
                model.add(
                    keras.layers.Bidirectional(rnn(output_dim),
                                               merge_mode=mode,
                                               input_shape=(timesteps, dim)))
                model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
                model.fit(x, y, epochs=1, batch_size=1)

                # check whether the model variables are present in the
                # checkpointable list of objects
                checkpointed_objects = set(
                    checkpointable_util.list_objects(model))
                for v in model.variables:
                    self.assertIn(v, checkpointed_objects)

                # test compute output shape
                ref_shape = model.layers[-1].output.get_shape()
                shape = model.layers[-1].compute_output_shape(
                    (None, timesteps, dim))
                self.assertListEqual(shape.as_list(), ref_shape.as_list())

                # test config
                model.get_config()
                model = keras.models.model_from_json(model.to_json())
                model.summary()
Example #27
  def test_bidirectional(self):
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2
    with self.cached_session():
      for mode in ['sum', 'concat', 'ave', 'mul']:
        x = np.random.random((samples, timesteps, dim))
        target_dim = 2 * output_dim if mode == 'concat' else output_dim
        y = np.random.random((samples, target_dim))

        # test with Sequential model
        model = keras.models.Sequential()
        model.add(
            keras.layers.Bidirectional(
                rnn(output_dim), merge_mode=mode, input_shape=(timesteps, dim)))
        model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
        model.fit(x, y, epochs=1, batch_size=1)

        # check whether the model variables are present in the
        # checkpointable list of objects
        checkpointed_objects = set(checkpointable_util.list_objects(model))
        for v in model.variables:
          self.assertIn(v, checkpointed_objects)

        # test compute output shape
        ref_shape = model.layers[-1].output.get_shape()
        shape = model.layers[-1].compute_output_shape(
            (None, timesteps, dim))
        self.assertListEqual(shape.as_list(), ref_shape.as_list())

        # test config
        model.get_config()
        model = keras.models.model_from_json(model.to_json())
        model.summary()
Example #28
def _fill_meta_graph_def(meta_graph_def, obj, signature_functions,
                         object_saver):
    """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    obj: The checkpointable object being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    object_saver: A CheckpointableSaver to add to the MetaGraph.

  Returns:
    An _AssetInfo, which contains information to help creating the SavedModel.
  """
    signatures = {}
    # List objects from the eager context to make sure Optimizers give us the
    # right Graph-dependent variables.
    accessible_objects = util.list_objects(obj)
    resource_initializer_functions = _trace_resource_initializers(
        accessible_objects)
    exported_graph = ops.Graph()
    resource_initializer_ops = []
    with exported_graph.as_default():
        object_map, resource_map, asset_info = _map_resources(
            accessible_objects)
        for resource_initializer_function in resource_initializer_functions:
            asset_dependencies = []
            for capture in resource_initializer_function.graph.external_captures:
                asset_initializer = asset_info.asset_initializers_by_resource.get(
                    capture, None)
                if asset_initializer is not None:
                    asset_dependencies.append(asset_initializer)
            with ops.control_dependencies(asset_dependencies):
                resource_initializer_ops.append(
                    _call_function_with_mapped_captures(
                        resource_initializer_function, [], resource_map))
        with ops.control_dependencies(resource_initializer_ops):
            init_op = control_flow_ops.no_op()
        # Add the same op to the main_op collection and to the init_op
        # signature. The collection is for compatibility with older loader APIs;
        # only one will be executed.
        meta_graph_def.collection_def[
            constants.MAIN_OP_KEY].node_list.value.append(init_op.name)
        meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
            signature_def_utils.op_signature_def(
                init_op, constants.INIT_OP_SIGNATURE_KEY))

    # Saving an object-based checkpoint again gathers variables. We need to do the
    # gathering from the eager context so Optimizers save the right set of
    # variables, but want any operations associated with the save/restore to be in
    # the exported graph (thus the `to_graph` argument).
    saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)

    # We must resolve the concrete function to add to MetaGraph while in eager
    # mode.
    concrete_functions = []
    for accessible_object in accessible_objects:
        for function in function_serialization.list_all_polymorphic_functions(
                accessible_object).values():
            concrete_functions.extend(
                function_serialization.list_all_concrete_functions(function))

    with exported_graph.as_default():
        signatures = _generate_signatures(signature_functions, resource_map)
        for concrete_function in concrete_functions:
            concrete_function.add_to_graph()
        saver_def = saver.to_proto()
        meta_graph_def.saver_def.CopyFrom(saver_def)
    graph_def = exported_graph.as_graph_def(add_shapes=True)
    # Clean reference cycles so repeated export()s don't make work for the garbage
    # collector.
    ops.dismantle_graph(exported_graph)

    meta_graph_def.graph_def.CopyFrom(graph_def)
    meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
    meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
    for signature_key, signature in signatures.items():
        meta_graph_def.signature_def[signature_key].CopyFrom(signature)
    meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
    return asset_info
Example #29
def _fill_meta_graph_def(meta_graph_def, obj, signature_functions,
                         object_saver):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    obj: The checkpointable object being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    object_saver: A CheckpointableSaver to add to the MetaGraph.

  Returns:
    asset_filename_map, a dictionary mapping from asset base names to
    user-specified full asset paths, which should be copied to the SavedModel's
    assets/ directory.
  """
  signatures = {}
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = util.list_objects(obj)
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = _map_resources(accessible_objects)
    for resource_initializer_function in resource_initializer_functions:
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)
  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector.
  ops.dismantle_graph(exported_graph)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info.asset_filename_map