# Example 1 (score: 0)
 def test_save_steps_saves_periodically(self):
     """CheckpointSaver with save_steps=2 saves on every other step."""
     with self.graph.as_default():
         saver_monitor = learn.monitors.CheckpointSaver(self.model_dir,
                                                        save_steps=2,
                                                        scaffold=self.scaffold)
         saver_monitor.begin()
         self.scaffold.finalize()
         with session_lib.Session() as session:
             session.run(self.scaffold.init_op)

             def checkpointed_step():
                 # Global step value held by the most recent checkpoint.
                 return checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name)

             self._run(saver_monitor, 1, self.train_op, session)
             self._run(saver_monitor, 2, self.train_op, session)
             # Not saved: checkpoint still holds step 1.
             self.assertEqual(1, checkpointed_step())
             self._run(saver_monitor, 3, self.train_op, session)
             # saved
             self.assertEqual(3, checkpointed_step())
             self._run(saver_monitor, 4, self.train_op, session)
             # Not saved
             self.assertEqual(3, checkpointed_step())
             self._run(saver_monitor, 5, self.train_op, session)
             # saved
             self.assertEqual(5, checkpointed_step())
# Example 2 (score: 0)
    def test_train_max_steps_is_not_incremental(self):
        """`max_steps` is an absolute target, not a per-call increment.

        Trains twice into the same output directory with max_steps=10 then
        max_steps=15; after each call the checkpointed global step must equal
        the requested max_steps (i.e. the second call adds only 5 steps).
        """
        for max_steps in (10, 15):
            with ops.Graph().as_default() as g, self.session(g):
                with ops.control_dependencies(self._build_inference_graph()):
                    train_op = state_ops.assign_add(
                        variables_lib.get_global_step(), 1)
                learn.graph_actions.train(g,
                                          output_dir=self._output_dir,
                                          train_op=train_op,
                                          loss_op=constant_op.constant(2.0),
                                          max_steps=max_steps)
                step = checkpoint_utils.load_variable(
                    self._output_dir,
                    variables_lib.get_global_step().name)
                self.assertEqual(max_steps, step)
 def testGetTensor(self):
   """Every saved variable can be read back by name from the checkpoint."""
   checkpoint_dir = self.get_temp_dir()
   with self.cached_session() as session:
     expected_values = _create_checkpoints(session, checkpoint_dir)
   variable_names = ["var1", "var2", "var3", "useful_scope/var4"]
   for name, expected in zip(variable_names, expected_values):
     self.assertAllEqual(
         checkpoint_utils.load_variable(checkpoint_dir, name), expected)
# Example 4 (score: 0)
def print_tensors_in_checkpoint_file(file_name, tensor_name):
    """Prints tensors in a checkpoint file.

    If no `tensor_name` is provided, prints the tensor names and shapes
    in the checkpoint file.

    If `tensor_name` is provided, prints the content of the tensor.

    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Name of the tensor in the checkpoint file to print.
    """
    try:
        if not tensor_name:
            # No tensor requested: list every variable name with its shape.
            variables = checkpoint_utils.list_variables(file_name)
            for name, shape in variables:
                print("%s\t%s" % (name, str(shape)))
        else:
            print("tensor_name: ", tensor_name)
            print(checkpoint_utils.load_variable(file_name, tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        # Deliberately best-effort: this is a diagnostic CLI helper, so any
        # failure is reported to the user rather than raised.
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
 def testNoTensor(self):
   """Loading a name absent from the checkpoint raises an OpError."""
   checkpoint_dir = self.get_temp_dir()
   with self.cached_session() as session:
     _create_checkpoints(session, checkpoint_dir)
   with self.assertRaises(errors_impl.OpError):
     self.assertAllEqual(
         checkpoint_utils.load_variable(checkpoint_dir, "var5"), [])
# Example 6 (score: 0)
 def test_save_saves_at_end(self):
     """CheckpointSaver writes a final checkpoint when `end` is called."""
     with self.graph.as_default():
         saver_monitor = learn.monitors.CheckpointSaver(self.model_dir,
                                                        save_secs=2,
                                                        scaffold=self.scaffold)
         saver_monitor.begin()
         self.scaffold.finalize()
         with session_lib.Session() as session:
             session.run(self.scaffold.init_op)
             for step in (1, 2):
                 self._run(saver_monitor, step, self.train_op, session)
             saver_monitor.end(session)
             # The end-of-training checkpoint must hold the final step.
             self.assertEqual(
                 2,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
 def covariances(self):
     """Returns the covariances."""
     covs = checkpoint_utils.load_variable(
         self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
     return covs
 def clusters(self):
     """Returns cluster centers."""
     raw_clusters = checkpoint_utils.load_variable(
         self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
     # Drop the singleton middle axis from the stored tensor.
     return np.squeeze(raw_clusters, axis=1)
 def weights(self):
     """Returns the cluster weights."""
     cluster_weights = checkpoint_utils.load_variable(
         self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
     return cluster_weights
 def testNoCheckpoints(self):
   """Loading from a directory with no checkpoints raises an OpError."""
   missing_dir = self.get_temp_dir() + "/no_checkpoints"
   with self.assertRaises(errors_impl.OpError):
     self.assertAllEqual(
         checkpoint_utils.load_variable(missing_dir, "var1"), [])