def testLoadExistingVariablesDifferentShapeAllowReshape(self):
    """Restoring into a differently-shaped variable works when reshaping.

    The checkpoint stores `v0` with shape [1, 2]. It is restored into the
    [2, 1] variable `my_var0` because `reshape_variables=True` is passed to
    `assign_from_checkpoint_fn`. The scalar `v1` is restored as-is.
    """
    model_dir = tempfile.mkdtemp(prefix=os.path.join(
        self.get_temp_dir(),
        'load_existing_variables_different_shape_allow_reshape'))
    # mkdtemp has already created the directory, so this removes it again
    # before the checkpoint is written (presumably so the checkpoint helper
    # creates it fresh -- TODO confirm).
    if gfile.Exists(model_dir):
      gfile.DeleteRecursively(model_dir)

    init_value0 = [[10.0, 11.0]]
    init_value1 = 20.0
    var_names_to_values = {'v0': init_value0, 'v1': init_value1}
    # Target shape for `v0`, deliberately different from the stored [1, 2].
    var0_shape = [2, 1]

    with self.test_session() as sess:
      model_path = self.create_checkpoint_from_values(var_names_to_values,
                                                      model_dir)
      var0 = variables_lib2.variable('my_var0', shape=var0_shape)
      var1 = variables_lib2.variable('my_var1', shape=[])

      vars_to_restore = {'v0': var0, 'v1': var1}
      init_fn = variables_lib2.assign_from_checkpoint_fn(
          model_path, vars_to_restore, reshape_variables=True)

      # Initialize the variables.
      sess.run(variables_lib.global_variables_initializer())

      # Perform the assignment.
      init_fn(sess)

      # The restore reshapes the stored value into the target shape. For a
      # 1xN source that happens to equal a transpose, but reshape is the
      # behavior under test, so state the expectation that way.
      expected_var0 = np.reshape(np.array(init_value0), var0_shape)
      self.assertAllEqual(expected_var0, var0.eval())
      self.assertEqual(init_value1, var1.eval())
  def testMissingVariablesDict(self):
    """Variables absent from the checkpoint are skipped when ignored.

    `v2` has no entry in the checkpoint. With `ignore_missing_vars=True` the
    restore still assigns `v0` and `v1`, and `my_var2` keeps the value the
    initializer gave it.
    """
    model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
                                                     'missing_variables_dict'))
    # mkdtemp has already created the directory, so this removes it again
    # before the checkpoint is written (presumably so the checkpoint helper
    # creates it fresh -- TODO confirm).
    if gfile.Exists(model_dir):
      gfile.DeleteRecursively(model_dir)

    init_value0 = 10.0
    init_value1 = 20.0
    var_names_to_values = {'v0': init_value0, 'v1': init_value1}

    with self.test_session() as sess:
      model_path = self.create_checkpoint_from_values(var_names_to_values,
                                                      model_dir)
      var0 = variables_lib2.variable('my_var0', shape=[])
      var1 = variables_lib2.variable('my_var1', shape=[])
      var2 = variables_lib2.variable('my_var2', shape=[])

      vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2}
      init_fn = variables_lib2.assign_from_checkpoint_fn(
          model_path, vars_to_restore, ignore_missing_vars=True)

      # Initialize the variables.
      sess.run(variables_lib.global_variables_initializer())
      # Snapshot the variable that has no checkpoint entry, so we can verify
      # that the restore left it alone.
      var2_initial = var2.eval()

      # Perform the assignment.
      init_fn(sess)

      # Request and test the variable values:
      self.assertEqual(init_value0, var0.eval())
      self.assertEqual(init_value1, var1.eval())
      self.assertEqual(var2_initial, var2.eval())
  def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self):
    """Restoring into a differently-shaped variable fails by default.

    No `reshape_variables` argument is passed. Loading the [1, 2] checkpoint
    entry `v0` into the [2, 1] variable `my_var0` is therefore expected to
    raise InvalidArgumentError.
    """
    scratch_root = self.get_temp_dir()
    ckpt_dir = tempfile.mkdtemp(
        prefix=os.path.join(scratch_root, 'load_existing_vars_no_reshape'))
    # Drop the freshly created directory before the checkpoint is written.
    if gfile.Exists(ckpt_dir):
      gfile.DeleteRecursively(ckpt_dir)

    row_value = [[10.0, 11.0]]
    scalar_value = 20.0
    checkpoint_contents = {'v0': row_value, 'v1': scalar_value}

    with self.test_session() as sess:
      ckpt_path = self.create_checkpoint_from_values(checkpoint_contents,
                                                     ckpt_dir)
      column_var = variables_lib2.variable('my_var0', shape=[2, 1])
      scalar_var = variables_lib2.variable('my_var1', shape=[])

      restore_fn = variables_lib2.assign_from_checkpoint_fn(
          ckpt_path, {'v0': column_var, 'v1': scalar_var})

      sess.run(variables_lib.global_variables_initializer())

      # Without reshaping, the [1, 2] -> [2, 1] mismatch must be rejected.
      with self.assertRaises(errors_impl.InvalidArgumentError):
        restore_fn(sess)
  def testNotFoundError(self):
    """Restoring a variable missing from the checkpoint raises NotFoundError.

    `v2` has no checkpoint entry and `ignore_missing_vars` is not set, so
    running the restore function must fail.
    """
    scratch_root = self.get_temp_dir()
    ckpt_dir = tempfile.mkdtemp(
        prefix=os.path.join(scratch_root, 'not_found_error'))
    # Drop the freshly created directory before the checkpoint is written.
    if gfile.Exists(ckpt_dir):
      gfile.DeleteRecursively(ckpt_dir)

    first_value, second_value = 10.0, 20.0
    checkpoint_contents = {'v0': first_value, 'v1': second_value}

    with self.test_session() as sess:
      ckpt_path = self.create_checkpoint_from_values(checkpoint_contents,
                                                     ckpt_dir)
      var0, var1, var2 = [
          variables_lib2.variable(var_name, shape=[])
          for var_name in ('my_var0', 'my_var1', 'my_var2')
      ]

      # 'v2' is intentionally absent from the checkpoint.
      restore_fn = variables_lib2.assign_from_checkpoint_fn(
          ckpt_path, {'v0': var0, 'v1': var1, 'v2': var2})

      sess.run(variables_lib.global_variables_initializer())

      with self.assertRaises(errors_impl.NotFoundError):
        restore_fn(sess)
# Example #5
# 0
  def testTrainWithInitFromCheckpoint(self):
    """Warm-starting from a converged checkpoint keeps the loss low.

    Phase 1 trains a single step into `tmp_logs1/`, where the loss should
    still be high. Phase 2 continues to step 300 in the same directory, where
    the loss should be low. Phase 3 builds a fresh graph writing to
    `tmp_logs2/`, restores the global variables from `model.ckpt-300` via the
    scaffold's init_fn, and trains one more step.
    """
    first_logdir = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    second_logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    # Start from clean log directories (for running on jenkins).
    for stale_dir in (first_logdir, second_logdir):
      if gfile.Exists(stale_dir):
        gfile.DeleteRecursively(stale_dir)

    def train_with_checkpoints(seed, num_steps):
      """Trains in a new graph, saving a checkpoint to first_logdir each step."""
      with ops.Graph().as_default():
        random_seed.set_random_seed(seed)
        train_op = self.create_train_op()
        saver = saver_lib.Saver()
        return training.train(
            train_op,
            first_logdir,
            hooks=[
                basic_session_run_hooks.CheckpointSaverHook(
                    first_logdir, save_steps=1, saver=saver),
                basic_session_run_hooks.StopAtStepHook(num_steps=num_steps),
            ],
            save_checkpoint_secs=None)

    # A single step: the error should still be high.
    self.assertGreater(train_with_checkpoints(seed=0, num_steps=1), .5)

    # Continue training until convergence.
    converged_loss = train_with_checkpoints(seed=1, num_steps=300)
    self.assertIsNotNone(converged_loss)
    self.assertLess(converged_loss, .02)

    # Take one more step, starting from the converged checkpoint and writing
    # to a separate logdir. The loss should stay low.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      restore_fn = variables_lib.assign_from_checkpoint_fn(
          os.path.join(first_logdir, 'model.ckpt-300'),
          variables_lib2.global_variables())

      def init_fn(_, session):
        restore_fn(session)

      final_loss = training.train(
          train_op,
          second_logdir,
          scaffold=monitored_session.Scaffold(init_fn=init_fn),
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])

    self.assertIsNotNone(final_loss)
    self.assertLess(final_loss, .02)