def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be one of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
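# A minimal usage sketch (not from the original source) of the underlying
# tf.compat.v1.train.init_from_checkpoint call that _warm_start_var relies on.
# The checkpoint path and tensor names below are hypothetical.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default():
  kernel = tf.get_variable("dense/kernel", shape=[10, 4])
  # assignment_map maps checkpoint tensor names to current variables (or names).
  tf.train.init_from_checkpoint("/tmp/prev_model", {"old_dense/kernel": kernel})
  with tf.Session() as sess:
    # Running the initializers performs the checkpoint-backed assignment.
    sess.run(tf.global_variables_initializer())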
  def testNoAdditionalReadOpsForResourceVariables(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as session:
        my1 = resource_variable_ops.ResourceVariable([[0.0] * 10], name="my1")

        with ops.name_scope("init_from_checkpoint"):
          checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

        # Basic sanity checks:
        session.run(variables.global_variables_initializer())
        self.assertAllEqual(session.run(my1), v1)

    ops_in_init_from_checkpoint_scope = [
        op for op in g.get_operations()
        if (op.name.startswith("init_from_checkpoint/") and
            not op.name.startswith("init_from_checkpoint/checkpoint_initializer"
                                  ) and
            op.type != "AssignVariableOp" and
            op.type != "Identity")
    ]
    self.assertEqual(ops_in_init_from_checkpoint_scope, [])
Example #3
    def testInitialValueComesFromCheckpoint(self):
        checkpoint_dir = self.get_temp_dir()
        with self.test_session() as session:
            v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

        # New graph and session.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as session:
                with variable_scope.variable_scope(
                        "some_scope",
                        initializer=init_ops.zeros_initializer()):
                    my1 = variable_scope.get_variable("my1", [1, 10])

                before = my1.initialized_value()

                checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                                      {"var1": my1})

                after = my1.initialized_value()

                self.assertAllEqual(session.run(before), [[0.0] * 10])
                self.assertAllEqual(session.run(after), v1)

                session.run(variables.global_variables_initializer())

                self.assertAllEqual(session.run(my1), v1)
                self.assertAllEqual(session.run(my1.initialized_value()), v1)
                self.assertAllClose(session.run(before), v1)
                self.assertAllClose(session.run(after), v1)
                with self.assertRaises(AssertionError):
                    self.assertAllClose(v1, [[0.0] * 10])
  def testInitialValueComesFromCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope(
            "some_scope", initializer=init_ops.zeros_initializer()):
          my1 = variable_scope.get_variable("my1", [1, 10])

        before = my1.initialized_value()

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

        after = my1.initialized_value()

        self.assertAllEqual(session.run(before), [[0.0] * 10])
        self.assertAllEqual(session.run(after), v1)

        session.run(variables.global_variables_initializer())

        self.assertAllEqual(session.run(my1), v1)
        self.assertAllEqual(session.run(my1.initialized_value()), v1)
        self.assertAllClose(session.run(before), v1)
        self.assertAllClose(session.run(after), v1)
        with self.assertRaises(AssertionError):
          self.assertAllClose(v1, [[0.0] * 10])
Example #5
def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
    """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be one of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
  """
    if _is_variable(var):
        current_var_name = _infer_var_name([var])
    elif isinstance(var, list) and all(_is_variable(v) for v in var):
        current_var_name = _infer_var_name(var)
    elif isinstance(var, variables_lib.PartitionedVariable):
        current_var_name = _infer_var_name([var])
        var = var._get_variable_list()  # pylint: disable=protected-access
    else:
        raise TypeError(
            "var MUST be one of the following: a Variable, list of Variable or "
            "PartitionedVariable, but is {}".format(type(var)))
    if not prev_tensor_name:
        # Assume tensor name remains the same.
        prev_tensor_name = current_var_name
    checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
  def testInitFromCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable("my1", [1, 10])
          with variable_scope.variable_scope("some_other_scope"):
            my2 = variable_scope.get_variable("my2", [10, 10])
            with variable_scope.variable_scope("other_useful_scope"):
              my4 = variable_scope.get_variable("var4", [9, 9])
        my3 = variable_scope.get_variable("my3", [100, 100])

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var1": "some_scope/my1",
            "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
        })
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var2": "some_scope/some_other_scope/my2",
            "var3": my3,
        })

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(my1.eval(session), v1)
        self.assertAllEqual(my2.eval(session), v2)
        self.assertAllEqual(my3.eval(session), v3)
        self.assertAllEqual(my4.eval(session), v4)

        # Check that tensors are not explicitly in the graph.
        self.assertLess(len(str(session.graph.as_graph_def())), 29000)
  def testInitFromCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
      v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable("my1", [1, 10])
          with variable_scope.variable_scope("some_other_scope"):
            my2 = variable_scope.get_variable("my2", [10, 10])
            with variable_scope.variable_scope("other_useful_scope"):
              my4 = variable_scope.get_variable("var4", [9, 9])
        my3 = variable_scope.get_variable("my3", [100, 100])

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var1": "some_scope/my1",
            "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
        })
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var2": "some_scope/some_other_scope/my2",
            "var3": my3,
        })

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(my1.eval(session), v1)
        self.assertAllEqual(my2.eval(session), v2)
        self.assertAllEqual(my3.eval(session), v3)
        self.assertAllEqual(my4.eval(session), v4)

        # Check that tensors are not explicitly in the graph.
        self.assertLess(len(str(session.graph.as_graph_def())), 29000)
  def testInitialValueComesFromCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope(
            "some_scope", initializer=init_ops.zeros_initializer()):
          my1 = variable_scope.get_variable("my1", [1, 10])

        # At this point, my1.initialized_value() will add ops that reference
        # the zeros initializer of my1.
        before = variables.Variable(my1.initialized_value(), name="before")

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

        # At this point, my1.initialized_value() will add ops that reference
        # the newly set initializer of my1.
        after = variables.Variable(my1.initialized_value(), name="after")

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(session.run(my1), v1)
        self.assertAllEqual(session.run(my1.initialized_value()), v1)
        self.assertAllClose(session.run(before), [[0.0] * 10])
        self.assertAllClose(session.run(after), v1)
        with self.assertRaises(AssertionError):
          self.assertAllClose(session.run(before), session.run(after))
  def _get_dense_tensor(self, inputs, weight_collections=None,
                        trainable=None):
    """Private method that follows the signature of _get_dense_tensor."""
    # Get sparse IDs and weights.
    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
        inputs, weight_collections=weight_collections, trainable=trainable)
    sparse_ids = sparse_tensors.id_tensor
    sparse_weights = sparse_tensors.weight_tensor

    candidate_dense_tensors = self._get_candidate_dense_tensor(
        inputs, weight_collections, trainable)

    embedding_weights = self.layer_creator(
        weight_collections=weight_collections,
        scope=variable_scope.get_variable_scope())

    if self.ckpt_to_load_from is not None:
      to_restore = embedding_weights
      if isinstance(to_restore, variables.PartitionedVariable):
        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
          self.tensor_name_in_ckpt: to_restore
      })

    # Return embedding lookup result.
    return attention_safe_embedding_lookup_sparse(
        embedding_weights=embedding_weights,
        sparse_ids=sparse_ids,
        sparse_weights=sparse_weights,
        candidate_dense_tensors=candidate_dense_tensors,
        combiner=self.combiner,
        name='%s_weights' % self.name,
        max_norm=self.max_norm)
Example #10
    def testInitialValueComesFromCheckpoint(self):
        checkpoint_dir = self.get_temp_dir()
        with self.test_session() as session:
            v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

        # New graph and session.
        with ops.Graph().as_default() as g:
            with self.test_session(graph=g) as session:
                with variable_scope.variable_scope(
                        "some_scope",
                        initializer=init_ops.zeros_initializer()):
                    my1 = variable_scope.get_variable("my1", [1, 10])

                # At this point, my1.initialized_value() will add ops that reference
                # the zeros initializer of my1.
                before = variables.Variable(my1.initialized_value(),
                                            name="before")

                checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                                      {"var1": my1})

                # At this point, my1.initialized_value() will add ops that reference
                # the newly set initializer of my1.
                after = variables.Variable(my1.initialized_value(),
                                           name="after")

                session.run(variables.global_variables_initializer())
                self.assertAllEqual(session.run(my1), v1)
                self.assertAllEqual(session.run(my1.initialized_value()), v1)
                self.assertAllClose(session.run(before), [[0.0] * 10])
                self.assertAllClose(session.run(after), v1)
                with self.assertRaises(AssertionError):
                    self.assertAllClose(session.run(before),
                                        session.run(after))
Example #11
    def testNoAdditionalReadOpsForResourceVariables(self):
        checkpoint_dir = self.get_temp_dir()
        with self.test_session() as session:
            v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

        # New graph and session.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as session:
                my1 = resource_variable_ops.ResourceVariable([[0.0] * 10],
                                                             name="my1")

                with ops.name_scope("init_from_checkpoint"):
                    checkpoint_utils.init_from_checkpoint(
                        checkpoint_dir, {"var1": my1})

                # Basic sanity checks:
                session.run(variables.global_variables_initializer())
                self.assertAllEqual(session.run(my1), v1)

        ops_in_init_from_checkpoint_scope = [
            op for op in g.get_operations()
            if (op.name.startswith("init_from_checkpoint/") and not op.name.
                startswith("init_from_checkpoint/checkpoint_initializer")
                and op.type != "AssignVariableOp" and op.type != "Identity")
        ]
        self.assertEqual(ops_in_init_from_checkpoint_scope, [])
Example #12
 def init_and_verify(g):
   v1 = variable_scope.get_variable("new_var1", [1, 10])
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
       "var1": "new_var1",
   })
   with self.test_session(graph=g) as session:
     session.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(v1))
def init_lm_checkpoints(lm_dirs):
    assignment_map = {'LanguageModel/': 'LanguageModel/'}
    init_from_checkpoint(os.path.join(lm_dirs['forward'], 'ckpt'),
                         assignment_map=assignment_map)
    if lm_dirs['reverse'] is not None:
        assignment_map = {'LanguageModel/': 'LanguageModelReverse/'}
        init_from_checkpoint(os.path.join(lm_dirs['reverse'], 'ckpt'),
                             assignment_map=assignment_map)
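# Hypothetical call sketch (directory names are assumptions, not taken from the
# original source): warm-start only the forward LM and skip the reverse LM.
init_lm_checkpoints({'forward': '/tmp/lm_forward', 'reverse': None})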
Example #14
    def testRestoreRunsOnSameDevice(self):
        checkpoint_dir = self.get_temp_dir()
        with self.cached_session() as session:
            _create_checkpoints(session, checkpoint_dir)

        with ops.Graph().as_default():
            with ops.device("/job:ps"):
                with variable_scope.variable_scope("useful_scope"):
                    variable_scope.get_variable("var4", [9, 9])

            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, {"useful_scope/": "useful_scope/"})
Example #15
 def init_and_verify(g):
   v1 = variable_scope.get_variable("new_var1", [1, 10])
   # Use string add to create new object in each replica
   prefix = "new_"
   suffix = "var1"
   new_var1 = prefix + suffix
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
       "var1": new_var1,
   })
   with self.test_session(graph=g) as session:
     session.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(v1))
 def init_and_verify(g):
     v1 = variable_scope.get_variable("new_var1", [1, 10])
     # Use string add to create new object in each replica
     prefix = "new_"
     suffix = "var1"
     new_var1 = prefix + suffix
     checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
         "var1": new_var1,
     })
     with self.test_session(graph=g) as session:
         session.run(variables.global_variables_initializer())
         self.assertAllEqual(v1_value, self.evaluate(v1))
  def testRestoreRunsOnSameDevice(self):
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
      _create_checkpoints(session, checkpoint_dir)

    with ops.Graph().as_default():
      with ops.device("/job:ps"):
        with variable_scope.variable_scope("useful_scope"):
          my4 = variable_scope.get_variable("var4", [9, 9])

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"useful_scope/": "useful_scope/"})
 def init_and_verify(g):
     v1 = variable_scope.get_variable("new_var1", [1, 10])
     v2 = variable_scope.get_variable(
         "new_var2", [10, 10],
         synchronization=variable_scope.VariableSynchronization.ON_READ,
         aggregation=variable_scope.VariableAggregation.MEAN)
     checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
         "var1": "new_var1",
         "var2": "new_var2"
     })
     with self.session(graph=g) as session:
         session.run(variables.global_variables_initializer())
         self.assertAllEqual(v1_value, self.evaluate(v1))
         self.assertAllEqual(v2_value, self.evaluate(v2))
Example #19
    def testRestoreRunsOnSameDevice(self):
        checkpoint_dir = self.get_temp_dir()
        with self.test_session() as session:
            _create_checkpoints(session, checkpoint_dir)

        with ops.Graph().as_default():
            with ops.device("/job:ps"):
                with variable_scope.variable_scope("useful_scope"):
                    my4 = variable_scope.get_variable("var4", [9, 9])

            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, {"useful_scope/": "useful_scope/"})
            self.assertEqual(my4._initializer_op.op.inputs[1].device,
                             "/job:ps")
Example #20
 def init_and_verify(g):
   v1 = variable_scope.get_variable("new_var1", [1, 10])
   v2 = variable_scope.get_variable(
       "new_var2", [10, 10],
       synchronization=variable_scope.VariableSynchronization.ON_READ,
       aggregation=variable_scope.VariableAggregation.MEAN)
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
       "var1": "new_var1",
       "var2": "new_var2"
   })
   with self.session(graph=g) as session:
     session.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(v1))
     self.assertAllEqual(v2_value, self.evaluate(v2))
  def testRestoreRunsOnSameDevice(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      _create_checkpoints(session, checkpoint_dir)

    with ops.Graph().as_default():
      with ops.device("/job:ps"):
        with variable_scope.variable_scope("useful_scope"):
          my4 = variable_scope.get_variable("var4", [9, 9])

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"useful_scope/": "useful_scope/"})
      # initializer runs on the same task but always on CPU.
      self.assertEqual(my4._initializer_op.op.inputs[1].device,
                       "/job:ps/device:CPU:0")
Example #22
    def testInitWithScopeDoesNotCaptureSuffixes(self):
        checkpoint_dir = self.get_temp_dir()
        with self.test_session() as session:
            _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)

        with ops.Graph().as_default() as g:
            with variable_scope.variable_scope("useful_scope"):
                my4 = variable_scope.get_variable("var4", [9, 9])
            with variable_scope.variable_scope("useful_scope_1"):
                my5_init = [[1.0, 2.0], [3.0, 4.0]]
                my5 = variable_scope.get_variable("var5", initializer=my5_init)

            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, {"useful_scope/": "useful_scope/"})
            with self.session(graph=g) as session:
                session.run(variables.global_variables_initializer())
                self.assertAllEqual(my4.eval(session), v4)
                self.assertAllEqual(my5.eval(session), my5_init)
  def testInitWithScopeDoesNotCaptureSuffixes(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)

    with ops.Graph().as_default() as g:
      with variable_scope.variable_scope("useful_scope"):
        my4 = variable_scope.get_variable("var4", [9, 9])
      with variable_scope.variable_scope("useful_scope_1"):
        my5_init = [[1.0, 2.0], [3.0, 4.0]]
        my5 = variable_scope.get_variable("var5", initializer=my5_init)

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"useful_scope/": "useful_scope/"})
      with self.test_session(graph=g) as session:
        session.run(variables.global_variables_initializer())
        self.assertAllEqual(my4.eval(session), v4)
        self.assertAllEqual(my5.eval(session), my5_init)
  def testInitToRootCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        my1 = variable_scope.get_variable("var1", [1, 10])
        my2 = variable_scope.get_variable("var2", [10, 10])
        my3 = variable_scope.get_variable("var3", [100, 100])
        with variable_scope.variable_scope("useful_scope"):
          my4 = variable_scope.get_variable("var4", [9, 9])

        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"/": "/",})

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(my1.eval(session), v1)
        self.assertAllEqual(my2.eval(session), v2)
        self.assertAllEqual(my3.eval(session), v3)
        self.assertAllEqual(my4.eval(session), v4)
Example #25
  def testInitToRootCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
      v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as session:
        my1 = variable_scope.get_variable("var1", [1, 10])
        my2 = variable_scope.get_variable("var2", [10, 10])
        my3 = variable_scope.get_variable("var3", [100, 100])
        with variable_scope.variable_scope("useful_scope"):
          my4 = variable_scope.get_variable("var4", [9, 9])

        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"/": "/",})

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(my1.eval(session), v1)
        self.assertAllEqual(my2.eval(session), v2)
        self.assertAllEqual(my3.eval(session), v3)
        self.assertAllEqual(my4.eval(session), v4)
  def testInitFromCheckpointMissing(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          _ = variable_scope.get_variable("my1", [10, 10])
          _ = variable_scope.get_variable(
              "my2", [1, 10],
              dtype=dtypes.int64,
              initializer=init_ops.zeros_initializer())

        # No directory.
        with self.assertRaises(errors_impl.OpError):
          checkpoint_utils.init_from_checkpoint("no_dir",
                                                {"var1": "some_scope/my1"})

        # No variable in checkpoint.
        with self.assertRaises(ValueError):
          checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                                {"no_var": "some_scope/my1"})

        # No variable in the graph.
        with self.assertRaises(ValueError):
          checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                                {"var3": "some_scope/no_var"})

        # Shape mismatch.
        with self.assertRaises(ValueError):
          checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                                {"var1": "some_scope/my1"})

        # Variable 'my1' and 'my2' are missing in given checkpoint scope.
        with self.assertRaises(ValueError):
          checkpoint_utils.init_from_checkpoint(
              checkpoint_dir, {"useful_scope/": "some_scope/"})

        # Mapping is not to scope name.
        with self.assertRaises(ValueError):
          checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                                {"useful_scope": "some_scope/"})
  def testInitFromPartitionVar(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1 = _create_partition_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable(
              name="my1",
              shape=[100, 100],
              initializer=init_ops.zeros_initializer(),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=8 << 10))
          my1_var_list = my1._get_variable_list()
        # Create another variable with different partitions than the variable in
        # the checkpoint.
        with variable_scope.variable_scope("some_other_scope"):
          my2 = variable_scope.get_variable(
              name="var1",
              shape=[100, 100],
              initializer=init_ops.zeros_initializer(),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=16 << 10))
          my2_var_list = my2._get_variable_list()

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "scope/var1": "some_scope/my1",
            "scope/": "some_other_scope/"})

        session.run(variables.global_variables_initializer())
        my1_values = session.run(my1_var_list)
        self.assertAllEqual(my1_values, v1)
        my2_values = session.run(my2_var_list)
        # Verify we created different number of partitions.
        self.assertNotEquals(len(my2_values), len(v1))
        # Verify the values were correctly initialized in spite of different
        # partitions.
        full_my2_values = np.concatenate(my2_values, axis=0)
        full_v1_values = np.concatenate(v1, axis=0)
        self.assertAllEqual(full_my2_values, full_v1_values)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable(
              name="my1",
              shape=[100, 100],
              initializer=init_ops.truncated_normal_initializer(0.5),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=8 << 10))
          my1_var_list = my1._get_variable_list()

        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"scope/var1": my1_var_list,})

        session.run(variables.global_variables_initializer())
        my1_values = session.run(my1_var_list)
        self.assertAllEqual(my1_values, v1)
Example #28
    def _get_dense_tensor(self,
                          inputs,
                          weight_collections=None,
                          trainable=None):
        # Get sparse IDs and weights.
        sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
            inputs,
            weight_collections=weight_collections,
            trainable=trainable)
        sparse_ids = sparse_tensors.id_tensor
        batch_size = sparse_ids.dense_shape[0]
        dense_tensor_ids = sparse_ops.sparse_to_dense(
            sparse_ids.indices, [batch_size, self.max_sequence_length],
            sparse_ids.values,
            default_value=0)

        # Create embedding weight, and restore from checkpoint if necessary.
        embedding_weights = variable_scope.get_variable(
            name='embedding_weights',
            shape=(self.categorical_column._num_buckets,
                   self.embedding_dimension),  # pylint: disable=protected-access
            dtype=dtypes.float32,
            initializer=self.initializer,
            trainable=self.trainable and trainable,
            collections=weight_collections)
        if self.ckpt_to_load_from is not None:
            to_restore = embedding_weights
            if isinstance(to_restore, variables.PartitionedVariable):
                to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
            checkpoint_utils.init_from_checkpoint(
                self.ckpt_to_load_from, {self.tensor_name_in_ckpt: to_restore})

        #dense_tensor_ids = utils.tf_print(dense_tensor_ids, "dense:")
        embedding_inputs = embedding_lookup(embedding_weights,
                                            dense_tensor_ids,
                                            max_norm=self.max_norm)

        dropout = (self.dropout_keep_probabilities
                   if self.mode == model_fn_lib.ModeKeys.TRAIN else None)

        sequence_lengths = self._sequence_lengths(sparse_ids)
        if self.bidirectional_rnn:
            cell_fw = rnn_common.construct_rnn_cell(self.num_units,
                                                    self.cell_type, dropout)
            cell_bw = rnn_common.construct_rnn_cell(self.num_units,
                                                    self.cell_type, dropout)
            with ops.name_scope('RNN'):
                rnn_outputs, final_states = rnn.bidirectional_dynamic_rnn(
                    cell_fw,
                    cell_bw,
                    embedding_inputs,
                    sequence_length=sequence_lengths,
                    dtype=dtypes.float32)
                #outputs = layers.fully_connected(
                #    inputs=array_ops.concat(rnn_outputs, 2),
                #    num_outputs=self.num_units,
                #    activation_fn=self.activation_fn,
                #    trainable=True)
                return array_ops.concat(final_states, 1)
        else:
            cell = rnn_common.construct_rnn_cell(self.num_units,
                                                 self.cell_type, dropout)
            with ops.name_scope('RNN'):
                rnn_outputs, final_state = rnn.dynamic_rnn(
                    cell,
                    embedding_inputs,
                    sequence_length=sequence_lengths,
                    dtype=dtypes.float32)
                #rnn_outputs = utils.tf_print(rnn_outputs, "rnn_output:")
                #rnn_last_outputs = utils.tf_print(rnn_last_outputs, "rnn_last:")
                #outputs = layers.fully_connected(
                #    inputs=rnn_outputs,
                #    num_outputs=self.num_units,
                #    activation_fn=self.activation_fn,
                #    trainable=True)

                return final_state.h
def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be one of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) `PartitionedVariable`
      (iv) list of `Variable` and/or `PartitionedVariable`: The list may
        contain one or more variables that have been sharded.  For example:
        [Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
         PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
        where we have three whole Variables represented ('a', 'b', and 'c').
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Raises:
    ValueError: If prev_tensor_name is not None, but the given var represents
      more than one Variable.
    TypeError: If var is not one of the allowed types.
  """
  if _is_variable(var):
    current_var_name = _infer_var_name([var])
  elif isinstance(var, variables.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  elif (isinstance(var, list) and all(
      _is_variable(v) or isinstance(v, variables.PartitionedVariable)
      for v in var)):
    # Convert length-1 lists of vars to single tf.Variables.  This ensures that
    # checkpoint_utils.init_from_checkpoint() doesn't incorrectly assume
    # slice info is present.
    if len(var) == 1:
      current_var_name = _infer_var_name(var)
      var = var[0]
    else:
      # If we have multiple elements in var, we cannot assume they all
      # represent the same Variable.
      name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
          var, convert_variable_to_tensor=False)
      if prev_tensor_name:
        # Providing a prev_tensor_name is only viable if var represents a
        # single Variable.
        if len(name_to_var_dict) > 1:
          raise ValueError("var represented more than one Variable, but "
                           "prev_tensor_name was provided.")
        checkpoint_utils.init_from_checkpoint(prev_ckpt, {
            prev_tensor_name: var
        })
      else:
        # OpListToDict gives us roughly what we need, but
        # the values in the dict may be PartitionedVariables (which
        # init_from_checkpoint does not expect) that we need to convert to
        # lists.
        name_to_var_dict_fixed = {}
        for name, var in six.iteritems(name_to_var_dict):
          if isinstance(var, variables.PartitionedVariable):
            name_to_var_dict_fixed[name] = var._get_variable_list()  # pylint: disable=protected-access
          else:
            name_to_var_dict_fixed[name] = var
        checkpoint_utils.init_from_checkpoint(prev_ckpt, name_to_var_dict_fixed)
      return
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, PartitionedVariable, or "
        "list of Variable's and/or PartitionedVariable's, but is {}".format(
            type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
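# A hypothetical sketch (not from the original source) of the partitioned case
# handled above: init_from_checkpoint also accepts the list of slices that make
# up a partitioned variable. The checkpoint path and shapes are assumptions.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default():
  emb = tf.get_variable(
      "emb", shape=[100, 16], partitioner=tf.fixed_size_partitioner(4))
  # Passing the slices lets the restore use their slice info against the full
  # "emb" tensor stored in the checkpoint.
  tf.train.init_from_checkpoint("/tmp/prev_model", {"emb": list(emb)})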
def _warm_start_variables(ckpt_to_initialize_from,
                          vars_to_warm_start,
                          var_name_to_vocab_info,
                          var_name_to_prev_var_name):
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}

  if not var_name_to_prev_var_name:
    # Detect whether the checkpoint is object-based, in which case the
    # var_name_to_prev_var_name dictionary should map variable names to
    # checkpoint keys. If the user has specified var_name_to_prev_var_name, we
    # do not override it.
    var_name_to_prev_var_name = _get_object_checkpoint_renames(
        ckpt_to_initialize_from, grouped_variables.keys())

  warmstarted_count = 0

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      warmstarted_count += 1
      logging.debug(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name, vocab_info.new_vocab, vocab_info.new_vocab_size,
              vocab_info.old_vocab, (vocab_info.old_vocab_size if
                                     vocab_info.old_vocab_size > 0 else "All"),
              vocab_info.num_oov_buckets, prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        warmstarted_count += 1
        logging.debug("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  logging.info("Warm-started %d variables.", warmstarted_count)

  if vocab_info_not_used:
    raise ValueError(
      "You provided the following variables in "
      "var_name_to_vocab_info that were not used: {0}. "
      " Perhaps you misspelled them?  Here is the list of viable variable "
      "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))

  return prev_var_name_used, set(grouped_variables.keys())
Example #31
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
    """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection).  This expression will only
        consider variables in the TRAINABLE_VARIABLES collection -- if you need
        to warm-start non_TRAINABLE vars (such as optimizer accumulators or
        batch norm statistics), please use the below option.
      - A list of Variables to warm-start.  If you do not have access to the
        `Variable` objects at the call site, please use the below option.
      - A list of strings, each a regex scope provided to
        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
        tf.compat.v1.get_collection).  For backwards compatibility reasons,
        this is separate from the single-string argument type.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
    if var_name_to_vocab_info is None:
        var_name_to_vocab_info = {}
    if var_name_to_prev_var_name is None:
        var_name_to_prev_var_name = {}
    logging.info("Warm-starting from: %s", (ckpt_to_initialize_from, ))
    grouped_variables = _get_grouped_variables(vars_to_warm_start)

    # Keep track of which var_names in var_name_to_prev_var_name and
    # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
    # exception if any are unused by the end of the loop.  It is easy to misname
    # a variable during this configuration, in which case without this check, we
    # would fail to warm-start silently.
    prev_var_name_used = set()
    vocab_info_used = set()

    # Group the vocabless vars into one call to init_from_checkpoint.
    vocabless_vars = {}
    for var_name, variable in six.iteritems(grouped_variables):
        prev_var_name = var_name_to_prev_var_name.get(var_name)
        if prev_var_name:
            prev_var_name_used.add(var_name)
        vocab_info = var_name_to_vocab_info.get(var_name)
        if vocab_info:
            vocab_info_used.add(var_name)
            logging.debug(
                "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
                " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
                " initializer: {}".format(
                    var_name, vocab_info.new_vocab, vocab_info.new_vocab_size,
                    vocab_info.old_vocab,
                    (vocab_info.old_vocab_size
                     if vocab_info.old_vocab_size > 0 else "All"),
                    vocab_info.num_oov_buckets, prev_var_name or "Unchanged",
                    vocab_info.backup_initializer or "zero-initialized"))
            _warm_start_var_with_vocab(
                variable,
                current_vocab_path=vocab_info.new_vocab,
                current_vocab_size=vocab_info.new_vocab_size,
                prev_ckpt=ckpt_to_initialize_from,
                prev_vocab_path=vocab_info.old_vocab,
                previous_vocab_size=vocab_info.old_vocab_size,
                current_oov_buckets=vocab_info.num_oov_buckets,
                prev_tensor_name=prev_var_name,
                initializer=vocab_info.backup_initializer,
                axis=vocab_info.axis)
        else:
            # For the special value of vars_to_warm_start = None,
            # we only warm-start variables with explicitly specified vocabularies.
            if vars_to_warm_start:
                logging.debug(
                    "Warm-starting variable: {}; prev_var_name: {}".format(
                        var_name, prev_var_name or "Unchanged"))
                # Because we use a default empty list in grouped_variables, single
                # unpartitioned variables will be lists here, which we rectify in order
                # for init_from_checkpoint logic to work correctly.
                if len(variable) == 1:
                    variable = variable[0]
                prev_tensor_name, var = _get_var_info(variable, prev_var_name)
                vocabless_vars[prev_tensor_name] = var

    checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                          vocabless_vars)
    prev_var_name_not_used = set(
        var_name_to_prev_var_name.keys()) - prev_var_name_used
    vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

    if prev_var_name_not_used:
        raise ValueError(
            "You provided the following variables in "
            "var_name_to_prev_var_name that were not used: "
            "{0}.  Perhaps you misspelled them?  Here is the list of viable "
            "variable names: {1}".format(prev_var_name_not_used,
                                         grouped_variables.keys()))
    if vocab_info_not_used:
        raise ValueError(
            "You provided the following variables in "
            "var_name_to_vocab_info that were not used: {0}. "
            " Perhaps you misspelled them?  Here is the list of viable variable "
            "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
Example #32
def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
    """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be one of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) `PartitionedVariable`
      (iv) list of `Variable` and/or `PartitionedVariable`: The list may
        contain one or more variables that have been sharded.  For example:
        [Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
         PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
        where we have three whole Variables represented ('a', 'b', and 'c').
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Raises:
    ValueError: If prev_tensor_name is not None, but the given var represents
      more than one Variable.
    TypeError: If var is not one of the allowed types.
  """
    if _is_variable(var):
        current_var_name = _infer_var_name([var])
    elif isinstance(var, variables.PartitionedVariable):
        current_var_name = _infer_var_name([var])
        var = var._get_variable_list()  # pylint: disable=protected-access
    elif (isinstance(var, list) and all(
            _is_variable(v) or isinstance(v, variables.PartitionedVariable)
            for v in var)):
        # Convert length-1 lists of vars to single tf.Variables.  This ensures that
        # checkpoint_utils.init_from_checkpoint() doesn't incorrectly assume
        # slice info is present.
        if len(var) == 1:
            current_var_name = _infer_var_name(var)
            var = var[0]
        else:
            # If we have multiple elements in var, we cannot assume they all
            # represent the same Variable.
            name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
                var, convert_variable_to_tensor=False)
            if prev_tensor_name:
                # Providing a prev_tensor_name is only viable if var represents a
                # single Variable.
                if len(name_to_var_dict) > 1:
                    raise ValueError(
                        "var represented more than one Variable, but "
                        "prev_tensor_name was provided.")
                checkpoint_utils.init_from_checkpoint(prev_ckpt,
                                                      {prev_tensor_name: var})
            else:
                # OpListToDict gives us roughly what we need, but
                # the values in the dict may be PartitionedVariables (which
                # init_from_checkpoint does not expect) that we need to convert to
                # lists.
                name_to_var_dict_fixed = {}
                for name, var in six.iteritems(name_to_var_dict):
                    if isinstance(var, variables.PartitionedVariable):
                        name_to_var_dict_fixed[name] = var._get_variable_list()  # pylint: disable=protected-access
                    else:
                        name_to_var_dict_fixed[name] = var
                checkpoint_utils.init_from_checkpoint(prev_ckpt,
                                                      name_to_var_dict_fixed)
            return
    else:
        raise TypeError(
            "var MUST be one of the following: a Variable, PartitionedVariable, or "
            "list of Variable's and/or PartitionedVariable's, but is {}".
            format(type(var)))
    if not prev_tensor_name:
        # Assume tensor name remains the same.
        prev_tensor_name = current_var_name
    checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.get_collection).  This expression will only consider
        variables in the TRAINABLE_VARIABLES collection -- if you need to
        warm-start non_TRAINABLE vars (such as optimizer accumulators or batch
        norm statistics), please use the below option.
      - A list of Variables to warm-start.  If you do not have access to the
        `Variable` objects at the call site, please use the below option.
      - A list of strings, each a regex scope provided to tf.get_collection with
        GLOBAL_VARIABLES (please see tf.get_collection).  For backwards
        compatibility reasons, this is separate from the single-string argument
        type.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).
  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name,
              vocab_info.new_vocab,
              vocab_info.new_vocab_size,
              vocab_info.old_vocab,
              (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
               else "All"),
              vocab_info.num_oov_buckets,
              prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
Example #34
    def testInitFromPartitionVar(self):
        checkpoint_dir = self.get_temp_dir()
        with self.test_session() as session:
            v1 = _create_partition_checkpoints(session, checkpoint_dir)

        # New graph and session.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as session:
                with variable_scope.variable_scope("some_scope"):
                    my1 = variable_scope.get_variable(
                        name="my1",
                        shape=[100, 100],
                        initializer=init_ops.zeros_initializer(),
                        partitioner=partitioned_variables.
                        min_max_variable_partitioner(max_partitions=5,
                                                     axis=0,
                                                     min_slice_size=8 << 10))
                    my1_var_list = my1._get_variable_list()
                # Create another variable with a different partitioning than the
                # variable in the checkpoint.
                with variable_scope.variable_scope("some_other_scope"):
                    my2 = variable_scope.get_variable(
                        name="var1",
                        shape=[100, 100],
                        initializer=init_ops.zeros_initializer(),
                        partitioner=partitioned_variables.
                        min_max_variable_partitioner(max_partitions=5,
                                                     axis=0,
                                                     min_slice_size=16 << 10))
                    my2_var_list = my2._get_variable_list()

                checkpoint_utils.init_from_checkpoint(
                    checkpoint_dir, {
                        "scope/var1": "some_scope/my1",
                        "scope/": "some_other_scope/"
                    })

                session.run(variables.global_variables_initializer())
                my1_values = session.run(my1_var_list)
                self.assertAllEqual(my1_values, v1)
                my2_values = session.run(my2_var_list)
                # Verify we created a different number of partitions.
                self.assertNotEqual(len(my2_values), len(v1))
                # Verify the values were correctly initialized despite the
                # different partitioning.
                full_my2_values = np.concatenate(my2_values, axis=0)
                full_v1_values = np.concatenate(v1, axis=0)
                self.assertAllEqual(full_my2_values, full_v1_values)

        # New graph and session.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as session:
                with variable_scope.variable_scope("some_scope"):
                    my1 = variable_scope.get_variable(
                        name="my1",
                        shape=[100, 100],
                        initializer=init_ops.truncated_normal_initializer(0.5),
                        partitioner=partitioned_variables.
                        min_max_variable_partitioner(max_partitions=5,
                                                     axis=0,
                                                     min_slice_size=8 << 10))
                    my1_var_list = my1._get_variable_list()

                checkpoint_utils.init_from_checkpoint(
                    checkpoint_dir, {
                        "scope/var1": my1_var_list,
                    })

                session.run(variables.global_variables_initializer())
                my1_values = session.run(my1_var_list)
                self.assertAllEqual(my1_values, v1)
Example #35
    def testInitFromCheckpointMissing(self):
        checkpoint_dir = self.get_temp_dir()
        with self.test_session() as session:
            _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

        # New graph and session.
        with ops.Graph().as_default() as g:
            with self.session(graph=g) as session:
                with variable_scope.variable_scope("some_scope"):
                    _ = variable_scope.get_variable("my1", [10, 10])
                    _ = variable_scope.get_variable(
                        "my2", [1, 10],
                        dtype=dtypes.int64,
                        initializer=init_ops.zeros_initializer())

                # No directory.
                with self.assertRaises(errors_impl.OpError):
                    checkpoint_utils.init_from_checkpoint(
                        "no_dir", {"var1": "some_scope/my1"})

                # No variable in checkpoint.
                with self.assertRaises(ValueError):
                    checkpoint_utils.init_from_checkpoint(
                        checkpoint_dir, {"no_var": "some_scope/my1"})

                # No variable in the graph.
                with self.assertRaises(ValueError):
                    checkpoint_utils.init_from_checkpoint(
                        checkpoint_dir, {"var3": "some_scope/no_var"})

                # Shape mismatch.
                with self.assertRaises(ValueError):
                    checkpoint_utils.init_from_checkpoint(
                        checkpoint_dir, {"var1": "some_scope/my1"})

                # Variables 'my1' and 'my2' are missing from the given checkpoint scope.
                with self.assertRaises(ValueError):
                    checkpoint_utils.init_from_checkpoint(
                        checkpoint_dir, {"useful_scope/": "some_scope/"})

                # Scope mappings must use a trailing '/' on both sides.
                with self.assertRaises(ValueError):
                    checkpoint_utils.init_from_checkpoint(
                        checkpoint_dir, {"useful_scope": "some_scope/"})