示例#1
0
    def testReturnsSingleCheckpointIfOneShardedCheckpoint(self):
        """A sharded save must be reported as one checkpoint, not one per shard.

        Saves a single checkpoint with a sharded Saver while variables are
        placed on two different CPU devices, then checks that
        ``checkpoints_iterator`` yields exactly one item for the directory.
        """
        import shutil  # local import: only used for temp-dir cleanup here.

        checkpoint_dir = tempfile.mkdtemp('one_checkpoint_found_sharded')
        # mkdtemp leaves deletion to the caller; remove the directory and the
        # shard files written below even if an assertion fails. No
        # Exists/MakeDirs check is needed: mkdtemp has already created it.
        self.addCleanup(shutil.rmtree, checkpoint_dir, ignore_errors=True)

        global_step = variables.get_or_create_global_step()

        # This will result in 3 different checkpoint shard files.
        with ops.device('/cpu:0'):
            variables_lib.Variable(10, name='v0')
        with ops.device('/cpu:1'):
            variables_lib.Variable(20, name='v1')

        saver = saver_lib.Saver(sharded=True)

        # Configure two CPU devices so the '/cpu:1' placement above exists.
        with session_lib.Session(
                target='',
                config=config_pb2.ConfigProto(device_count={'CPU': 2})) as session:
            session.run(variables_lib.global_variables_initializer())
            save_path = os.path.join(checkpoint_dir, 'model.ckpt')
            saver.save(session, save_path, global_step=global_step)

        # timeout=0 so the iterator stops instead of waiting for new
        # checkpoints; this test relies on the loop terminating.
        num_found = sum(
            1 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0))
        self.assertEqual(num_found, 1)
示例#2
0
    def testReturnsEmptyIfNoCheckpointsFound(self):
        """An empty directory yields no checkpoints, and the iterator ends."""
        import shutil  # local import: only used for temp-dir cleanup here.

        checkpoint_dir = tempfile.mkdtemp('no_checkpoints_found')
        # mkdtemp leaves deletion to the caller; clean up even on failure.
        self.addCleanup(shutil.rmtree, checkpoint_dir, ignore_errors=True)

        # timeout=0 so the iterator stops instead of waiting for checkpoints.
        num_found = sum(
            1 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0))
        self.assertEqual(num_found, 0)
示例#3
0
    def testTimeoutFn(self):
        """The iterator keeps polling until ``timeout_fn`` returns True.

        ``timeout_fn`` reports "done" only on its 4th call, so the iterator
        must invoke it exactly four times before it stops. Because the
        directory does not exist, nothing is yielded.
        """
        call_count = [0]  # mutable cell so the closure can update it

        def timeout_fn():
            call_count[0] += 1
            # False for calls 1-3, True from the 4th call onward.
            return call_count[0] > 3

        checkpoints = [
            path for path in evaluation.checkpoints_iterator(
                '/non-existent-dir', timeout=0.1, timeout_fn=timeout_fn)
        ]
        self.assertEqual([], checkpoints)
        self.assertEqual(4, call_count[0])
示例#4
0
    def testReturnsSingleCheckpointIfOneCheckpointFound(self):
        """A directory holding one saved checkpoint yields exactly one item."""
        import shutil  # local import: only used for temp-dir cleanup here.

        checkpoint_dir = tempfile.mkdtemp('one_checkpoint_found')
        # mkdtemp leaves deletion to the caller; remove the directory and the
        # checkpoint files written below even if an assertion fails. No
        # Exists/MakeDirs check is needed: mkdtemp has already created it.
        self.addCleanup(shutil.rmtree, checkpoint_dir, ignore_errors=True)

        global_step = variables.get_or_create_global_step()
        saver = saver_lib.Saver()  # Saves the global step.

        with self.cached_session() as session:
            session.run(variables_lib.global_variables_initializer())
            save_path = os.path.join(checkpoint_dir, 'model.ckpt')
            saver.save(session, save_path, global_step=global_step)

        # timeout=0 so the iterator stops instead of waiting for new
        # checkpoints; this test relies on the loop terminating.
        num_found = sum(
            1 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0))
        self.assertEqual(num_found, 1)
示例#5
0
 def testMonitorCheckpointsLoopTimeout(self):
     """With timeout=0 and a missing directory, the iterator yields nothing."""
     found = [
         path for path in evaluation_lib.checkpoints_iterator(
             '/non-existent-dir', timeout=0)
     ]
     self.assertEqual(found, [])