Code Example #1
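Verifies that evaluation.evaluate_repeatedly forwards its feed_dict to the eval ops: a local variable is incremented by a fed placeholder value on each of num_evals evaluations, and final_ops reads back the accumulated total.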
    def testEvaluateWithEvalFeedDict(self):
        # Create a checkpoint.
        checkpoint_dir = os.path.join(self.get_temp_dir(),
                                      'evaluate_with_eval_feed_dict')
        self._train_model(checkpoint_dir, num_steps=1)

        # We need a variable that the saver will try to restore.
        variables.get_or_create_global_step()

        # Create a variable and an eval op that increments it with a placeholder.
        my_var = variables.local_variable(0.0, name='my_var')
        increment = array_ops.placeholder(dtype=dtypes.float32)
        eval_ops = state_ops.assign_add(my_var, increment)

        increment_value = 3
        num_evals = 5
        expected_value = increment_value * num_evals
        final_values = evaluation.evaluate_repeatedly(
            checkpoint_dir=checkpoint_dir,
            eval_ops=eval_ops,
            feed_dict={increment: increment_value},
            final_ops={'my_var': array_ops.identity(my_var)},
            hooks=[
                evaluation.StopAfterNEvalsHook(num_evals),
            ],
            max_number_of_evaluations=1)
        self.assertEqual(final_values['my_var'], expected_value)
Code Example #2
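Verifies that the evaluation loop honors its timeout argument: the eval op is a placeholder that is never fed, so no evaluation can complete, and the loop must return an empty result within roughly the six-second timeout.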
    def testEvaluationLoopTimeout(self):
        checkpoint_dir = os.path.join(self.get_temp_dir(),
                                      'evaluation_loop_timeout')
        if not gfile.Exists(checkpoint_dir):
            gfile.MakeDirs(checkpoint_dir)

        # We need a variable that the saver will try to restore.
        variables.get_or_create_global_step()

        # Run with placeholders. If we actually try to evaluate this, we'd fail
        # since we're not using a feed_dict.
        cant_run_op = array_ops.placeholder(dtype=dtypes.float32)

        start = time.time()
        final_values = evaluation.evaluate_repeatedly(
            checkpoint_dir=checkpoint_dir,
            eval_ops=cant_run_op,
            hooks=[evaluation.StopAfterNEvalsHook(10)],
            timeout=6)
        end = time.time()
        self.assertFalse(final_values)

        # Assert that we've waited for the duration of the timeout (minus the sleep
        # time).
        self.assertGreater(end - start, 5.0)

        # Then assert that the timeout kicked in and stopped the loop.
        self.assertLess(end - start, 7)
Code Example #3
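Verifies that learning.create_train_op leaves the global step untouched when called with global_step=None: after ten training steps the global step still reads 0.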
  def testNoneGlobalStep(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = BatchNormClassifier(tf_inputs)
      loss_ops.log_loss(tf_predictions, tf_labels)
      total_loss = loss_ops.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = learning.create_train_op(
          total_loss, optimizer, global_step=None)

      global_step = variables_lib2.get_or_create_global_step()

      with session.Session() as sess:
        # Initialize all variables
        sess.run(variables_lib.global_variables_initializer())

        for _ in range(10):
          sess.run([train_op])
        global_step = global_step.eval()
        # Since train_op doesn't use global_step, it shouldn't change.
        self.assertAllClose(global_step, 0)
Code Example #4
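Verifies that evaluation.checkpoints_iterator yields exactly one checkpoint when the directory holds a single sharded checkpoint, created by saving variables placed on two CPU devices with a sharded Saver.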
    def testReturnsSingleCheckpointIfOneShardedCheckpoint(self):
        checkpoint_dir = os.path.join(self.get_temp_dir(),
                                      'one_checkpoint_found_sharded')
        if not gfile.Exists(checkpoint_dir):
            gfile.MakeDirs(checkpoint_dir)

        global_step = variables.get_or_create_global_step()

        # This will result in 3 different checkpoint shard files.
        with ops.device('/cpu:0'):
            variables_lib.Variable(10, name='v0')
        with ops.device('/cpu:1'):
            variables_lib.Variable(20, name='v1')

        saver = saver_lib.Saver(sharded=True)

        with session_lib.Session(target='',
                                 config=config_pb2.ConfigProto(
                                     device_count={'CPU': 2})) as session:

            session.run(variables_lib.global_variables_initializer())
            save_path = os.path.join(checkpoint_dir, 'model.ckpt')
            saver.save(session, save_path, global_step=global_step)

        num_found = 0
        for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0):
            num_found += 1
        self.assertEqual(num_found, 1)
Code Example #5
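The counterpart of Code Example #3 for training.create_train_op: with global_step=None, ten training steps leave the global step at 0.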
    def testGlobalStepNotIncrementedWhenSetToNone(self):
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = batchnorm_classifier(tf_inputs)
            loss = losses.log_loss(tf_labels, tf_predictions)
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            train_op = training.create_train_op(loss,
                                                optimizer,
                                                global_step=None)

            global_step = variables_lib.get_or_create_global_step()

            with self.cached_session() as session:
                # Initialize all variables
                session.run(variables_lib2.global_variables_initializer())

                for _ in range(10):
                    session.run(train_op)

                # Since train_op doesn't use global_step, it shouldn't change.
                self.assertAllClose(global_step.eval(), 0)
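Code Example #6
Verifies that FeatureImportanceSummarySaver.begin() raises a KeyError when the graph does not contain the tensors the hook requires.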
  def test_invalid_graph(self):
    # Create inputs.
    model_dir = tempfile.mkdtemp()
    hook = trainer_hooks.FeatureImportanceSummarySaver(model_dir)
    with ops.Graph().as_default():
      # Begin won't be able to find the required tensors in the graph.
      _ = variables.get_or_create_global_step()
      with self.assertRaises(KeyError):
        hook.begin()
Code Example #7
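A shared test fixture: setUp generates test data, computes the expected ground-truth accuracy, and builds the global step, input, label, and prediction tensors used by the evaluation tests.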
    def setUp(self):
        super(EvaluationTest, self).setUp()

        num_classes = 8
        batch_size = 16
        inputs, labels = GenerateTestData(num_classes, batch_size)
        self._expected_accuracy = GroundTruthAccuracy(inputs, labels,
                                                      batch_size)

        self._global_step = variables_lib.get_or_create_global_step()
        self._inputs = constant_op.constant(inputs, dtype=dtypes.float32)
        self._labels = constant_op.constant(labels, dtype=dtypes.int64)
        self._predictions, self._scale = TestModel(self._inputs)
Code Example #8
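The unsharded counterpart of Code Example #4: after a single checkpoint is saved, evaluation.checkpoints_iterator with timeout=0 yields exactly one checkpoint path.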
    def testReturnsSingleCheckpointIfOneCheckpointFound(self):
        checkpoint_dir = os.path.join(self.get_temp_dir(),
                                      'one_checkpoint_found')
        if not gfile.Exists(checkpoint_dir):
            gfile.MakeDirs(checkpoint_dir)

        global_step = variables.get_or_create_global_step()
        saver = saver_lib.Saver()  # Saves the global step.

        with self.cached_session() as session:
            session.run(variables_lib.global_variables_initializer())
            save_path = os.path.join(checkpoint_dir, 'model.ckpt')
            saver.save(session, save_path, global_step=global_step)

        num_found = 0
        for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0):
            num_found += 1
        self.assertEqual(num_found, 1)
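Code Example #9
An end-to-end run of FeatureImportanceSummarySaver: begin() locates the gbdt tensors, the hook runs inside a monitored_session._HookedSession, and end() writes one summary directory per feature.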
  def test_run(self):
    # Create inputs.
    model_dir = tempfile.mkdtemp()
    hook = trainer_hooks.FeatureImportanceSummarySaver(model_dir)
    with ops.Graph().as_default(), tf_session.Session() as sess:
      global_step = variables.get_or_create_global_step()
      with ops.name_scope("gbdt"):
        constant_op.constant(["featA", "featB"], name="feature_names")
        constant_op.constant([0, 2], name="feature_usage_counts")
        constant_op.constant([0, 0.8], name="feature_gains")
      # Begin finds tensors in the graph.
      hook.begin()
      sess.run(tf_variables.global_variables_initializer())
      # Run hook in a monitored session.
      train_op = state_ops.assign_add(global_step, 1)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)
      hook.end(sess)
      # Ensure output summary dirs are created.
      self.assertTrue(os.path.exists(os.path.join(model_dir, "featA")))
      self.assertTrue(os.path.exists(os.path.join(model_dir, "featB")))