示例#1
0
  def testInitFromCheckpointWithScopes(self):
    """Checkpoint values restore into variables created under other scopes.

    The checkpoint stores 'layer0/v0' and 'layer1/v1'. Those values are
    assigned, through an explicit checkpoint-name -> variable mapping, to
    variables created inside the 'my_model/my_layer0' and
    'my_model/my_layer1' variable scopes, and the restored values are then
    compared against what was saved.
    """
    tmp_prefix = os.path.join(self.get_temp_dir(),
                              'init_from_checkpoint_with_scopes')
    checkpoint_dir = tempfile.mkdtemp(prefix=tmp_prefix)

    # Shapes (1, 3, 1) and (2, 1, 2), written out as nested literals.
    expected_v0 = np.array([[[1.0], [3.0], [9.0]]], dtype=np.float32)
    expected_v1 = np.array([[[2.0, 4.0]], [[6.0, 8.0]]], dtype=np.float32)

    saved_values = {'layer0/v0': expected_v0, 'layer1/v1': expected_v1}

    with self.test_session() as sess:
      checkpoint_path = self.create_checkpoint_from_values(saved_values,
                                                           checkpoint_dir)
      with variable_scope.variable_scope('my_model/my_layer0'):
        restored_v0 = variables_lib2.variable(
            'my_var0', shape=expected_v0.shape)
      with variable_scope.variable_scope('my_model/my_layer1'):
        restored_v1 = variables_lib2.variable(
            'my_var1', shape=expected_v1.shape)

      # Keys are names inside the checkpoint; values are the in-graph targets.
      assign_op, assign_feed = variables_lib2.assign_from_checkpoint(
          checkpoint_path,
          {'layer0/v0': restored_v0, 'layer1/v1': restored_v1})

      # Run the initializers first, then overwrite with checkpoint values.
      sess.run(variables_lib.global_variables_initializer())
      sess.run(assign_op, assign_feed)

      for expected, restored in ((expected_v0, restored_v0),
                                 (expected_v1, restored_v1)):
        self.assertAllEqual(expected, restored.eval())
示例#2
0
  def testLoadExistingVariables(self):
    """Scalar checkpoint values are assigned onto differently named variables.

    The checkpoint holds 'v0' and 'v1'; they are restored into the graph
    variables 'my_var0' and 'my_var1' and read back for comparison.
    """
    checkpoint_dir = tempfile.mkdtemp(
        prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))

    saved = {'v0': 10.0, 'v1': 20.0}

    with self.test_session() as sess:
      checkpoint_path = self.create_checkpoint_from_values(saved,
                                                           checkpoint_dir)
      # Checkpoint name -> variable that should receive its value.
      targets = {
          'v0': variables_lib2.variable('my_var0', shape=[]),
          'v1': variables_lib2.variable('my_var1', shape=[]),
      }
      assign_op, assign_feed = variables_lib2.assign_from_checkpoint(
          checkpoint_path, targets)

      # Initialize everything, then overwrite from the checkpoint.
      sess.run(variables_lib.global_variables_initializer())
      sess.run(assign_op, assign_feed)

      for ckpt_name, var in targets.items():
        self.assertEqual(saved[ckpt_name], var.eval())
示例#3
0
  def testRaisesValueErrorIfAVariableIsntFound(self):
    """Checks that an unknown checkpoint name raises ValueError.

    The mapping passed to assign_from_checkpoint refers to 'v0_fake', which
    was never written to the checkpoint.
    """
    checkpoint_dir = tempfile.mkdtemp(prefix=os.path.join(
        self.get_temp_dir(), 'raises_value_error_if_var_isnt_found'))

    with self.test_session():
      checkpoint_path = self.create_checkpoint_from_values(
          {'v0': 10.0, 'v1': 20.0}, checkpoint_dir)
      my_var0 = variables_lib2.variable('my_var0', shape=[])
      my_var1 = variables_lib2.variable('my_var1', shape=[])

      # 'v0_fake' is absent from the checkpoint written above.
      bad_mapping = {'v0_fake': my_var0, 'v1': my_var1}
      self.assertRaises(ValueError, variables_lib2.assign_from_checkpoint,
                        checkpoint_path, bad_mapping)
示例#4
0
    def testTrainWithInitFromCheckpoint(self):
        """Trains to low loss, then warm-starts a fresh run from the checkpoint.

        Three phases, each in a new graph:
          1. One training step into ``logdir1``; the loss must still be high.
          2. Train in ``logdir1`` with ``number_of_steps=num_converge_steps``;
             the loss must be low.
          3. Build a new model, restore all global variables from the
             phase-2 checkpoint via ``assign_from_checkpoint`` (run as
             ``init_fn``), train one step into ``logdir2``, and check the
             loss is still low.
        """
        # Create the log dirs *inside* the test temp dir. Passing the absolute
        # get_temp_dir() path as ``prefix`` makes mkdtemp ignore its default
        # directory and create a sibling of the test temp dir instead.
        logdir1 = os.path.join(tempfile.mkdtemp(dir=self.get_temp_dir()),
                               'tmp_logs1')
        logdir2 = os.path.join(tempfile.mkdtemp(dir=self.get_temp_dir()),
                               'tmp_logs2')
        # Shared by the phase-2 step count and the checkpoint path in phase 3
        # so the two cannot drift apart.
        num_converge_steps = 300

        # First, train the model one step (make sure the error is high).
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op()
            loss = learning.train(train_op, logdir1, number_of_steps=1)
            # Check for None first so a missing loss fails with a clear
            # assertion message instead of a TypeError in assertGreater.
            self.assertIsNotNone(loss)
            self.assertGreater(loss, .5)

        # Next, train the model to convergence.
        with ops.Graph().as_default():
            random_seed.set_random_seed(1)
            train_op = self.create_train_op()
            loss = learning.train(train_op,
                                  logdir1,
                                  number_of_steps=num_converge_steps,
                                  log_every_n_steps=10)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)

        # Finally, advance the model a single step and validate that the loss is
        # still low.
        with ops.Graph().as_default():
            random_seed.set_random_seed(2)
            train_op = self.create_train_op()

            model_variables = variables_lib.global_variables()
            # Assumes learning.train names checkpoints 'model.ckpt-<step>'
            # (the original hard-coded 'model.ckpt-300').
            model_path = os.path.join(logdir1,
                                      'model.ckpt-%d' % num_converge_steps)

            init_op = variables_lib.global_variables_initializer()
            op, init_feed_dict = variables_lib2.assign_from_checkpoint(
                model_path, model_variables)

            def InitAssignFn(sess):
                # Overwrite the freshly initialized variables with the
                # values saved in phase 2.
                sess.run(op, init_feed_dict)

            loss = learning.train(train_op,
                                  logdir2,
                                  number_of_steps=1,
                                  init_op=init_op,
                                  init_fn=InitAssignFn)

            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)
示例#5
0
  def testTrainWithInitFromCheckpoint(self):
    """Trains to low loss, then warm-starts a fresh run from the checkpoint.

    Three phases, each in a new graph:
      1. One training step into ``logdir1``; the loss must still be high.
      2. Train in ``logdir1`` with ``number_of_steps=num_converge_steps``;
         the loss must be low.
      3. Build a new model, restore all global variables from the phase-2
         checkpoint via ``assign_from_checkpoint`` (run as ``init_fn``),
         train one step into ``logdir2``, and check the loss is still low.
    """
    # Create the log dirs *inside* the test temp dir. Passing the absolute
    # get_temp_dir() path as ``prefix`` makes mkdtemp ignore its default
    # directory and create a sibling of the test temp dir instead.
    logdir1 = os.path.join(
        tempfile.mkdtemp(dir=self.get_temp_dir()), 'tmp_logs1')
    logdir2 = os.path.join(
        tempfile.mkdtemp(dir=self.get_temp_dir()), 'tmp_logs2')
    # Shared by the phase-2 step count and the checkpoint path in phase 3 so
    # the two cannot drift apart.
    num_converge_steps = 300

    # First, train the model one step (make sure the error is high).
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op()
      loss = learning.train(train_op, logdir1, number_of_steps=1)
      # Check for None first so a missing loss fails with a clear assertion
      # message instead of a TypeError in assertGreater.
      self.assertIsNotNone(loss)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      train_op = self.create_train_op()
      loss = learning.train(
          train_op,
          logdir1,
          number_of_steps=num_converge_steps,
          log_every_n_steps=10)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = variables_lib.global_variables()
      # Assumes learning.train names checkpoints 'model.ckpt-<step>' (the
      # original hard-coded 'model.ckpt-300').
      model_path = os.path.join(logdir1,
                                'model.ckpt-%d' % num_converge_steps)

      init_op = variables_lib.global_variables_initializer()
      op, init_feed_dict = variables_lib2.assign_from_checkpoint(
          model_path, model_variables)

      def InitAssignFn(sess):
        # Overwrite the freshly initialized variables with the values saved
        # in phase 2.
        sess.run(op, init_feed_dict)

      loss = learning.train(
          train_op,
          logdir2,
          number_of_steps=1,
          init_op=init_op,
          init_fn=InitAssignFn)

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)